# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------

import mmcv
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from mmcv.runner import BaseModule, force_fp32, auto_fp16
from mmdet.core import multi_apply
from mmdet.models import DETECTORS
from mmdet.models.builder import build_backbone, build_detector
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
                          merge_aug_bboxes_3d, show_result)
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector

from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from projects.mmdet3d_plugin import SPConvVoxelization
from projects.mmdet3d_plugin.models.feedforward_networks.moe import SparseMoE, load_balancing_loss_func
from projects.mmdet3d_plugin.utils import CUDATimer, Timer

from .pipeline_detector import PipelineDetector

from typing import List, Dict, Optional


@DETECTORS.register_module()
class SceneDetector(BaseModule):
    """Scene-adaptive detector choosing between camera+LiDAR and camera-only models.

    For every sample a ``scene_classifier`` decides whether LiDAR is used.
    The selected sub-detector -- ``camera_lidar_detector`` (CMT) or
    ``camera_only_detector`` (StreamPETR) -- always processes ``CAM_FRONT``
    and runs its router on the front-camera features to decide which of the
    remaining cameras are loaded. Data is loaded by sample index through the
    sub-detectors' ``load_img`` / ``load_pts`` methods, and per-stage timings
    are collected with ``CUDATimer``.

    Args:
        camera_lidar_detector (optional): Config of the camera+LiDAR
            detector; must build a ``PipelineDetector``.
        camera_only_detector (optional): Config of the camera-only detector;
            must build a ``PipelineDetector``.
        scene_classifier (optional): Config of the scene classifier; must
            build a ``PipelineDetector``.
        train_cfg, test_cfg, pretrained: Accepted for config compatibility;
            unused.
        **kwargs: Forwarded to ``BaseModule``.
    """

    # Non-front nuScenes cameras, indexed by the router's camera-selection mask.
    _OTHER_CAMERAS = ('CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK',
                      'CAM_BACK_LEFT', 'CAM_BACK_RIGHT')

    def __init__(self,
                 camera_lidar_detector=None,
                 camera_only_detector=None,
                 scene_classifier=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 **kwargs):
        super(SceneDetector, self).__init__(**kwargs)
        if camera_lidar_detector is not None:
            print('building camera_lidar_detector')
            self.camera_lidar_detector = build_detector(camera_lidar_detector)
            assert isinstance(self.camera_lidar_detector, PipelineDetector)
        if camera_only_detector is not None:
            print('building camera_only_detector')
            self.camera_only_detector = build_detector(camera_only_detector)
            assert isinstance(self.camera_only_detector, PipelineDetector)
        if scene_classifier is not None:
            print('building scene_classifier')
            self.scene_classifier = build_detector(scene_classifier)
            assert isinstance(self.scene_classifier, PipelineDetector)

    @classmethod
    def _mask_to_cameras(cls, selected_cams_mask):
        """Return the names of the cameras flagged in row 0 of a router mask.

        Non-zero entries count as selected. The mask row is copied to the host
        once (``tolist``) instead of synchronizing per element.
        """
        flags = selected_cams_mask[0].to(torch.bool).tolist()
        return [cls._OTHER_CAMERAS[i] for i, flag in enumerate(flags) if flag]

    def forward(self, sample_idx):
        """Run scene-adaptive inference for one sample (inference only).

        Args:
            sample_idx: Sample index passed to the sub-detectors' loaders.

        Returns:
            list[dict]: One dict per constructed ``img_metas`` entry, holding
            ``'pts_bbox'`` (detector output) and ``'metas'`` (sample token,
            routing decisions, data-loading/processing times and the
            ``CUDATimer`` time dict).
        """
        assert not self.training

        time_data_loading = 0.
        time_data_processing = 0.
        infer_time_dict = dict()

        # ---------------- scene classification ----------------
        scene_classifier_img_data = self.scene_classifier.load_img(sample_idx)
        scene_classifier_imgs = scene_classifier_img_data['imgs']
        scene_classifier_token = scene_classifier_img_data['token']
        time_data_loading += scene_classifier_img_data['time_data_loading']
        time_data_processing += scene_classifier_img_data['time_data_processing']

        with CUDATimer('scene_classifier') as t:
            scene_classifier_imgs = scene_classifier_imgs.unsqueeze(0).cuda()
            scene_classifier_outputs = self.scene_classifier(scene_classifier_imgs)
            # a single LiDAR logit is thresholded at 0.5, several are arg-maxed
            lidar_logits = scene_classifier_outputs[0, -1]
            if len(lidar_logits) == 1:
                use_lidar = (lidar_logits >= 0.5).item()
            else:
                use_lidar = lidar_logits.argmax(dim=0).item()
        infer_time_dict.update(t.get_time_dict(details=True))

        if use_lidar:
            # ---------------- camera + LiDAR detector (CMT) ----------------
            # load and process the CAM_FRONT image for CMT
            cam_front_data = self.camera_lidar_detector.load_img(sample_idx, used_cameras=['CAM_FRONT'])
            img_metas_cam_front = [cam_front_data['img_metas'][0].data]
            img_cam_front = cam_front_data['img'][0].data.cuda()
            time_data_loading += cam_front_data['time_data_loading']
            time_data_processing += cam_front_data['time_data_processing']

            assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'

            with CUDATimer('cmt_cam_front_feat_and_router') as t:
                # extract CAM_FRONT features and run the router on them
                img_cam_front_feats = self.camera_lidar_detector.detector.extract_img_feat(img_cam_front, img_metas_cam_front)
                selected_cams_mask, _ = self.camera_lidar_detector.detector.router_forward(img_cam_front, img_cam_front_feats)
            infer_time_dict.update(t.get_time_dict(details=True))

            selected_cams = self._mask_to_cameras(selected_cams_mask)

            if selected_cams:
                # load and process the router-selected images for CMT
                cam_other_data = self.camera_lidar_detector.load_img(sample_idx, used_cameras=selected_cams)
                img_metas_cam_other = [cam_other_data['img_metas'][0].data]
                img_cam_other = cam_other_data['img'][0].data.cuda()
                time_data_loading += cam_other_data['time_data_loading']
                time_data_processing += cam_other_data['time_data_processing']

                with CUDATimer('cmt_cam_other_feat') as t:
                    img_cam_other_feats = self.camera_lidar_detector.detector.extract_img_feat(img_cam_other, img_metas_cam_other)
                    # concatenate each front/other feature pair along dim 0 (CAM_FRONT first)
                    img_feats = tuple(
                        torch.concat([front_feat, other_feat], dim=0)
                        for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats))
                infer_time_dict.update(t.get_time_dict(details=True))
            else:
                img_feats = img_cam_front_feats

            pts_data = self.camera_lidar_detector.load_pts(sample_idx)
            points = [pts_data['points'][0].data.cuda()]
            time_data_loading += pts_data['time_data_loading']
            time_data_processing += pts_data['time_data_processing']

            with CUDATimer('cmt_lidar_feat') as t:
                # note: CMT's extract_pts_feat does not use img_feats or img_metas
                pts_feats = self.camera_lidar_detector.detector.extract_pts_feat(points, None, None)
            infer_time_dict.update(t.get_time_dict(details=True))

            # construct img_metas; per-camera entries list CAM_FRONT first
            img_metas = [dict()]
            img_metas[0]['sample_idx'] = img_metas_cam_front[0]['sample_idx']
            img_metas[0]['box_type_3d'] = img_metas_cam_front[0]['box_type_3d']
            for key in ('filename', 'lidar2img', 'pad_shape'):
                if selected_cams:
                    img_metas[0][key] = img_metas_cam_front[0][key] + img_metas_cam_other[0][key]
                else:
                    img_metas[0][key] = img_metas_cam_front[0][key]

            # note: CMT's transformer uses `lidar2img` of all cameras to build its queries;
            # `post_rot=0.5, post_tran=[0., -130.]` is used in the testing of CMT
            img_metas_all = self.camera_lidar_detector.dataset.get_img_meta(sample_idx, post_rot=0.5, post_tran=[0., -130.])
            img_metas[0]['lidar2img_query'] = img_metas_all['lidar2img']

            with CUDATimer('cmt_transformer_and_bbox') as t:
                bbox_list = [dict() for _ in img_metas]
                bbox_pts = self.camera_lidar_detector.detector.simple_test_pts(
                    pts_feats, img_feats, img_metas, rescale=False)
                for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
                    result_dict['pts_bbox'] = pts_bbox
            infer_time_dict.update(t.get_time_dict(details=True))

        else:
            # ---------------- camera-only detector (StreamPETR) ----------------
            # load and process the CAM_FRONT image for StreamPETR
            cam_front_data = self.camera_only_detector.load_img(sample_idx, used_cameras=['CAM_FRONT'])
            img_metas_cam_front = [cam_front_data['img_metas'][0].data]
            img_cam_front = cam_front_data['img'][0].data.unsqueeze(0).cuda()
            time_data_loading += cam_front_data['time_data_loading']
            time_data_processing += cam_front_data['time_data_processing']

            assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'

            with CUDATimer('streampetr_cam_front_feat_and_router') as t:
                img_cam_front_feats = self.camera_only_detector.detector.extract_img_feat(img_cam_front, 1)
                selected_cams_mask, _ = self.camera_only_detector.detector.router_forward(img_cam_front_feats)
            infer_time_dict.update(t.get_time_dict(details=True))

            selected_cams = self._mask_to_cameras(selected_cams_mask)

            if selected_cams:
                # load and process the router-selected images for StreamPETR
                cam_other_data = self.camera_only_detector.load_img(sample_idx, used_cameras=selected_cams)
                img_metas_cam_other = [cam_other_data['img_metas'][0].data]
                img_cam_other = cam_other_data['img'][0].data.unsqueeze(0).cuda()
                time_data_loading += cam_other_data['time_data_loading']
                time_data_processing += cam_other_data['time_data_processing']

                with CUDATimer('streampetr_cam_other_feat') as t:
                    img_cam_other_feats = self.camera_only_detector.detector.extract_img_feat(img_cam_other, 1)
                infer_time_dict.update(t.get_time_dict(details=True))

            # construct the keyword inputs of StreamPETR's simple_test_pts
            data = dict()
            for key in ('ego_pose', 'ego_pose_inv', 'timestamp'):
                data[key] = cam_front_data[key][0].data.unsqueeze(0).cuda()
            camera_keys = ('lidar2img', 'intrinsics', 'extrinsics', 'img_timestamp')
            if selected_cams:
                data['img'] = torch.concat([img_cam_front, img_cam_other], dim=1)
                for key in camera_keys:
                    data[key] = torch.concat(
                        [cam_front_data[key][0].data, cam_other_data[key][0].data],
                        dim=0).unsqueeze(0).cuda()
                data['img_feats'] = [
                    torch.concat([front_feat, other_feat], dim=1)
                    for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats)]
            else:
                data['img'] = img_cam_front
                for key in camera_keys:
                    data[key] = cam_front_data[key][0].data.unsqueeze(0).cuda()
                data['img_feats'] = img_cam_front_feats

            # construct img_metas
            img_metas = [dict()]
            img_metas[0]['sample_idx'] = img_metas_cam_front[0]['sample_idx']
            img_metas[0]['scene_token'] = img_metas_cam_front[0]['scene_token']
            img_metas[0]['box_type_3d'] = img_metas_cam_front[0]['box_type_3d']
            # `img_metas_cam_other` only exists when the router picked extra cameras
            if selected_cams:
                img_metas[0]['pad_shape'] = img_metas_cam_front[0]['pad_shape'] + img_metas_cam_other[0]['pad_shape']
            else:
                img_metas[0]['pad_shape'] = img_metas_cam_front[0]['pad_shape']

            with CUDATimer('streampetr_transformer_and_bbox') as t:
                # +1 is for CAM_FRONT
                self.camera_only_detector.set_num_cams_valid(int(selected_cams_mask.sum().item()) + 1)

                bbox_list = [dict() for _ in img_metas]
                bbox_pts = self.camera_only_detector.detector.simple_test_pts(img_metas, **data)

                for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
                    result_dict['pts_bbox'] = pts_bbox
            infer_time_dict.update(t.get_time_dict(details=True))

        for result_dict in bbox_list:
            result_dict['metas'] = dict()
            result_dict['metas']['sample_token'] = scene_classifier_token
            result_dict['metas']['use_lidar'] = use_lidar
            result_dict['metas']['used_cameras'] = ['CAM_FRONT'] + selected_cams
            result_dict['metas']['time_data_loading'] = time_data_loading
            result_dict['metas']['time_data_processing'] = time_data_processing
            result_dict['metas']['infer_time_dict'] = infer_time_dict

        return bbox_list


@DETECTORS.register_module()
class SceneDetectorV2(BaseModule):
    """Scene-adaptive detector operating on pre-loaded raw data dicts.

    A ``scene_classifier`` decides per sample whether the camera+LiDAR
    detector (CMT) or the camera-only detector (StreamPETR) is used. Each
    detector's router selects, from ``CAM_FRONT`` features, which of the other
    cameras to process. ``forward`` receives raw data dicts, picks the needed
    cameras itself (``collect_camera_lidar_data`` /
    ``collect_camera_only_data``) and runs the sub-detectors'
    ``img_pipeline`` / ``pts_pipeline``. In the camera+LiDAR branch, camera
    work is issued on ``self.stream_camera`` while LiDAR work is issued on the
    default CUDA stream of device 0.

    Args:
        camera_lidar_detector, camera_only_detector, scene_classifier (optional):
            Configs passed to ``build_detector``; each must build a
            ``PipelineDetector``.
        separate_camera_stream (bool): If True, camera work uses a new CUDA
            stream on device 0; otherwise device 0's default stream.
        train_cfg, test_cfg, pretrained: Accepted for config compatibility;
            unused.
        **kwargs: Forwarded to ``BaseModule``.
    """

    # All nuScenes cameras; a name's index is its position in the per-camera
    # lists of the raw data dicts.
    _ALL_CAMERAS = ('CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
                    'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT')
    _CAMERA_NAME_TO_ID = dict(zip(_ALL_CAMERAS, range(len(_ALL_CAMERAS))))
    # Non-front cameras, indexed by the routers' camera-selection masks.
    _OTHER_CAMERAS = _ALL_CAMERAS[1:]

    # Keys copied unchanged from the raw data by the `collect_*` methods.
    _CAMERA_LIDAR_SHARED_KEYS = (
        'sample_idx', 'timestamp', 'box_type_3d', 'box_mode_3d', 'img_norm_cfg',
        'img_fields', 'bbox3d_fields', 'pts_mask_fields', 'pts_seg_fields',
        'bbox_fields', 'mask_fields', 'seg_fields')
    _CAMERA_ONLY_SHARED_KEYS = (
        'sample_idx', 'prev_idx', 'scene_token', 'frame_idx', 'prev_exists',
        'ego_pose', 'ego_pose_inv', 'timestamp', 'box_type_3d', 'box_mode_3d',
        'img_norm_cfg', 'img_fields', 'bbox3d_fields', 'pts_mask_fields',
        'pts_seg_fields', 'bbox_fields', 'mask_fields', 'seg_fields')
    # Shape keys: the first three entries are kept and the number of
    # selected cameras is appended.
    _SHAPE_KEYS = ('img_shape', 'ori_shape', 'pad_shape')
    # Per-camera list keys, filtered down to the selected cameras.
    _CAMERA_LIDAR_PER_CAMERA_KEYS = (
        'img_filename', 'filename', 'img_timestamp', 'lidar2img', 'lidar2cam',
        'cam_intrinsic', 'img')
    _CAMERA_ONLY_PER_CAMERA_KEYS = (
        'img_filename', 'filename', 'img_timestamp', 'lidar2img', 'intrinsics',
        'extrinsics', 'img')

    def __init__(self,
                 camera_lidar_detector=None,
                 camera_only_detector=None,
                 scene_classifier=None,
                 separate_camera_stream=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 **kwargs):
        super(SceneDetectorV2, self).__init__(**kwargs)
        if camera_lidar_detector is not None:
            print('building camera_lidar_detector')
            self.camera_lidar_detector = build_detector(camera_lidar_detector)
            assert isinstance(self.camera_lidar_detector, PipelineDetector)
        if camera_only_detector is not None:
            print('building camera_only_detector')
            self.camera_only_detector = build_detector(camera_only_detector)
            assert isinstance(self.camera_only_detector, PipelineDetector)
        if scene_classifier is not None:
            print('building scene_classifier')
            self.scene_classifier = build_detector(scene_classifier)
            assert isinstance(self.scene_classifier, PipelineDetector)

        if separate_camera_stream:
            self.stream_camera = torch.cuda.Stream(device=0)
        else:
            self.stream_camera = torch.cuda.default_stream(device=0)

    @classmethod
    def _collect_common(cls, raw_data, cameras, shared_keys, per_camera_keys):
        """Build the part of a ``collect_*`` result common to both detectors.

        Copies ``shared_keys`` as-is, rewrites the shape keys for the number of
        selected cameras and, when at least one camera is selected, keeps only
        the selected cameras' entries of ``per_camera_keys`` (ordered by camera
        index). Unknown camera names raise ``KeyError``.
        """
        camera_ids = sorted(cls._CAMERA_NAME_TO_ID[name] for name in cameras)
        result = {key: raw_data[key] for key in shared_keys}
        for key in cls._SHAPE_KEYS:
            result[key] = tuple(list(raw_data[key][:3]) + [len(camera_ids)])
        if camera_ids:
            for key in per_camera_keys:
                result[key] = [raw_data[key][i] for i in camera_ids]
        return result

    def collect_camera_lidar_data(self, camera_lidar_data, cameras, use_lidar):
        """Select the given cameras (and optionally the LiDAR inputs) from raw CMT data.

        Args:
            camera_lidar_data (dict): Raw data for ``camera_lidar_detector``.
            cameras (list[str]): Camera names to keep; may be empty.
            use_lidar (bool): Whether to copy ``pts_filename``, ``sweeps`` and
                ``points``; otherwise ``pts_filename`` is None and ``sweeps``
                is empty.

        Returns:
            dict: Input for ``camera_lidar_detector.img_pipeline`` /
            ``pts_pipeline``.
        """
        result = self._collect_common(
            camera_lidar_data, cameras,
            self._CAMERA_LIDAR_SHARED_KEYS, self._CAMERA_LIDAR_PER_CAMERA_KEYS)

        if use_lidar:
            result['pts_filename'] = camera_lidar_data['pts_filename']
            result['sweeps'] = camera_lidar_data['sweeps']
            result['points'] = camera_lidar_data['points']
        else:
            result['pts_filename'] = None
            result['sweeps'] = []

        return result

    def collect_camera_only_data(self, camera_only_data, cameras):
        """Select the given cameras from raw StreamPETR data.

        Args:
            camera_only_data (dict): Raw data for ``camera_only_detector``.
            cameras (list[str]): Camera names to keep; may be empty.

        Returns:
            dict: Input for ``camera_only_detector.img_pipeline``.
        """
        return self._collect_common(
            camera_only_data, cameras,
            self._CAMERA_ONLY_SHARED_KEYS, self._CAMERA_ONLY_PER_CAMERA_KEYS)

    def get_img2lidar_query(self, cam_intrinsic, lidar2cam, post_rot=0.5, post_tran=(0., -130.)):
        """Compose per-camera ``intrinsic @ lidar2cam`` after a 2D image transform.

        For each camera, the top two rows of the intrinsic's first three
        columns are scaled by ``post_rot`` (via a scaled 2x2 identity) and
        ``post_tran`` is then added to the first two rows of column 2, before
        multiplying with ``lidar2cam[i]``. The intrinsics are deep-copied, so
        the caller's arrays are not modified. Repeated calls on the same data
        therefore give the same result.

        Args:
            cam_intrinsic (list): Per-camera intrinsic matrices.
            lidar2cam (list): Per-camera LiDAR-to-camera matrices; its length
                sets the number of cameras processed.
            post_rot (float): Scale applied to the intrinsics.
            post_tran (sequence[float]): 2D offset added to column 2.

        Returns:
            list: One projection matrix per camera.
        """
        post_rot = torch.eye(2) * post_rot
        post_tran = torch.tensor(post_tran, dtype=torch.float32)

        lidar2img_query = []
        for i, extrinsic in enumerate(lidar2cam):
            # copy so the caller's intrinsics are not transformed in place
            intrinsic = copy.deepcopy(cam_intrinsic[i])
            intrinsic[:2, :3] = post_rot @ intrinsic[:2, :3]
            intrinsic[:2, 2] = post_tran + intrinsic[:2, 2]
            lidar2img_query.append(intrinsic @ extrinsic)
        return lidar2img_query

    @classmethod
    def _mask_to_cameras(cls, selected_cams_mask):
        """Return the names of the cameras flagged in row 0 of a router mask.

        Non-zero entries count as selected. ``tolist`` copies the row to the
        host once, which synchronizes with the current CUDA stream.
        """
        flags = selected_cams_mask[0].to(torch.bool).tolist()
        return [cls._OTHER_CAMERAS[i] for i, flag in enumerate(flags) if flag]

    def _load_camera_lidar_imgs(self, camera_lidar_data, cameras):
        """Collect and preprocess ``cameras`` for CMT; return ``(img_metas, img)`` with ``img`` on GPU."""
        img_data = self.collect_camera_lidar_data(camera_lidar_data, cameras=cameras, use_lidar=False)
        img_data = self.camera_lidar_detector.img_pipeline(img_data)
        img_metas = [img_data['img_metas'][0].data]
        img = img_data['img'][0].data.cuda(non_blocking=True)
        return img_metas, img

    def _extract_lidar_feats(self, camera_lidar_data):
        """Run the LiDAR pipeline and CMT's point-feature extraction on the default stream."""
        # use the default cuda stream for lidar because spconv raises an error on a custom stream
        with torch.cuda.stream(torch.cuda.default_stream(0)):
            pts_data = self.collect_camera_lidar_data(camera_lidar_data, cameras=[], use_lidar=True)
            pts_data = self.camera_lidar_detector.pts_pipeline(pts_data)
            points = [pts_data['points'][0].data.cuda(non_blocking=True)]
            # note: CMT's extract_pts_feat does not use img_feats or img_metas
            return self.camera_lidar_detector.detector.extract_pts_feat(points, None, None)

    def _forward_camera_lidar(self, camera_lidar_data, scene_classifier_token):
        """Run CMT on CAM_FRONT, the router-selected cameras and LiDAR.

        Returns:
            tuple: ``(bbox_list, selected_cams)``.
        """
        with torch.cuda.stream(self.stream_camera):
            # load and process the CAM_FRONT image, extract its features and run the router
            img_metas_cam_front, img_cam_front = self._load_camera_lidar_imgs(camera_lidar_data, ['CAM_FRONT'])
            img_cam_front_feats = self.camera_lidar_detector.detector.extract_img_feat(img_cam_front, img_metas_cam_front)
            selected_cams_mask, _ = self.camera_lidar_detector.detector.router_forward(img_cam_front, img_cam_front_feats)
            selected_cams = self._mask_to_cameras(selected_cams_mask)  # GPU-CPU synchronize here

        if selected_cams:
            with torch.cuda.stream(self.stream_camera):
                # load, process and extract features of the router-selected images
                img_metas_cam_other, img_cam_other = self._load_camera_lidar_imgs(camera_lidar_data, selected_cams)
                img_cam_other_feats = self.camera_lidar_detector.detector.extract_img_feat(img_cam_other, img_metas_cam_other)
                # concatenate each front/other feature pair along dim 0 (CAM_FRONT first)
                img_feats = tuple(
                    torch.concat([front_feat, other_feat], dim=0)
                    for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats))
        else:
            img_feats = img_cam_front_feats

        # LiDAR work is issued on the default stream while camera work may still be running
        pts_feats = self._extract_lidar_feats(camera_lidar_data)

        # construct img_metas; per-camera entries list CAM_FRONT first
        img_metas = [dict()]
        img_metas[0]['sample_idx'] = img_metas_cam_front[0]['sample_idx']
        img_metas[0]['box_type_3d'] = img_metas_cam_front[0]['box_type_3d']
        for key in ('filename', 'lidar2img', 'pad_shape'):
            if selected_cams:
                img_metas[0][key] = img_metas_cam_front[0][key] + img_metas_cam_other[0][key]
            else:
                img_metas[0][key] = img_metas_cam_front[0][key]

        # note: CMT's transformer uses `lidar2img` of all cameras to build its queries
        img_metas[0]['lidar2img_query'] = self.get_img2lidar_query(camera_lidar_data['cam_intrinsic'], camera_lidar_data['lidar2cam'])

        # wait for the camera stream before the transformer consumes img_feats
        self.stream_camera.synchronize()

        # transformer and bbox decoder
        bbox_list = [dict() for _ in img_metas]
        bbox_pts = self.camera_lidar_detector.detector.simple_test_pts(
            pts_feats, img_feats, img_metas, rescale=False)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox

        assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'
        return bbox_list, selected_cams

    def _forward_camera_only(self, camera_only_data, scene_classifier_token):
        """Run StreamPETR on CAM_FRONT and the router-selected cameras.

        Returns:
            tuple: ``(bbox_list, selected_cams)``.
        """
        # load and process the CAM_FRONT image for StreamPETR
        cam_front_data = self.collect_camera_only_data(camera_only_data, cameras=['CAM_FRONT'])
        cam_front_data = self.camera_only_detector.img_pipeline(cam_front_data)
        img_metas_cam_front = [cam_front_data['img_metas'][0].data]
        img_cam_front = cam_front_data['img'][0].data.unsqueeze(0).cuda(non_blocking=True)

        assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'

        img_cam_front_feats = self.camera_only_detector.detector.extract_img_feat(img_cam_front, 1)
        selected_cams_mask, _ = self.camera_only_detector.detector.router_forward(img_cam_front_feats)
        selected_cams = self._mask_to_cameras(selected_cams_mask)

        if selected_cams:
            # load and process the router-selected images for StreamPETR
            cam_other_data = self.collect_camera_only_data(camera_only_data, cameras=selected_cams)
            cam_other_data = self.camera_only_detector.img_pipeline(cam_other_data)
            img_metas_cam_other = [cam_other_data['img_metas'][0].data]
            img_cam_other = cam_other_data['img'][0].data.unsqueeze(0).cuda(non_blocking=True)

            img_cam_other_feats = self.camera_only_detector.detector.extract_img_feat(img_cam_other, 1)

        # construct the keyword inputs of StreamPETR's simple_test_pts
        data = dict()
        for key in ('ego_pose', 'ego_pose_inv', 'timestamp'):
            data[key] = cam_front_data[key][0].data.unsqueeze(0).cuda(non_blocking=True)
        camera_keys = ('lidar2img', 'intrinsics', 'extrinsics', 'img_timestamp')
        if selected_cams:
            data['img'] = torch.concat([img_cam_front, img_cam_other], dim=1)
            for key in camera_keys:
                data[key] = torch.concat(
                    [cam_front_data[key][0].data, cam_other_data[key][0].data],
                    dim=0).unsqueeze(0).cuda(non_blocking=True)
            data['img_feats'] = [
                torch.concat([front_feat, other_feat], dim=1)
                for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats)]
        else:
            data['img'] = img_cam_front
            for key in camera_keys:
                data[key] = cam_front_data[key][0].data.unsqueeze(0).cuda(non_blocking=True)
            data['img_feats'] = img_cam_front_feats

        # construct img_metas
        img_metas = [dict()]
        img_metas[0]['sample_idx'] = img_metas_cam_front[0]['sample_idx']
        img_metas[0]['scene_token'] = img_metas_cam_front[0]['scene_token']
        img_metas[0]['box_type_3d'] = img_metas_cam_front[0]['box_type_3d']
        # `img_metas_cam_other` only exists when the router picked extra cameras
        if selected_cams:
            img_metas[0]['pad_shape'] = img_metas_cam_front[0]['pad_shape'] + img_metas_cam_other[0]['pad_shape']
        else:
            img_metas[0]['pad_shape'] = img_metas_cam_front[0]['pad_shape']

        # +1 is for CAM_FRONT
        self.camera_only_detector.set_num_cams_valid(int(selected_cams_mask.sum().item()) + 1)
        bbox_list = [dict() for _ in img_metas]
        bbox_pts = self.camera_only_detector.detector.simple_test_pts(img_metas, **data)

        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
        return bbox_list, selected_cams

    def forward(self, camera_lidar_data, camera_only_data, scene_classifier_data):
        """Run scene-adaptive inference for one sample (inference only).

        Args:
            camera_lidar_data (dict): Raw data for ``camera_lidar_detector``.
            camera_only_data (dict): Raw data for ``camera_only_detector``.
            scene_classifier_data (sequence): Only item 0 (classifier images)
                and item 1 (sample token) are read.

        Returns:
            list[dict]: One dict per constructed ``img_metas`` entry, holding
            ``'pts_bbox'`` and ``'metas'``. The timing fields of ``'metas'`` are
            None.
        """
        assert not self.training

        scene_classifier_imgs = scene_classifier_data[0]
        scene_classifier_token = scene_classifier_data[1]

        scene_classifier_imgs = scene_classifier_imgs.unsqueeze(0).cuda(non_blocking=True)
        scene_classifier_outputs = self.scene_classifier(scene_classifier_imgs)
        # a single LiDAR logit is thresholded at 0.5, several are arg-maxed
        lidar_logits = scene_classifier_outputs[0, -1]
        if len(lidar_logits) == 1:
            use_lidar = (lidar_logits >= 0.5).item()
        else:
            use_lidar = lidar_logits.argmax(dim=0).item()

        if use_lidar:
            bbox_list, selected_cams = self._forward_camera_lidar(camera_lidar_data, scene_classifier_token)
        else:
            bbox_list, selected_cams = self._forward_camera_only(camera_only_data, scene_classifier_token)

        for result_dict in bbox_list:
            result_dict['metas'] = dict()
            result_dict['metas']['sample_token'] = scene_classifier_token
            result_dict['metas']['use_lidar'] = use_lidar
            result_dict['metas']['used_cameras'] = ['CAM_FRONT'] + selected_cams
            result_dict['metas']['time_data_loading'] = None
            result_dict['metas']['time_data_processing'] = None
            result_dict['metas']['infer_time_dict'] = None

        return bbox_list


class SensorStatus:
    """On/off state machine for a single sensor with boot time and minimum on-time.

    The status moves ``'off' -> 'activating' -> 'activated'``. A sensor becomes
    activated ``boot_time`` seconds (measured with ``time.perf_counter``) after
    :meth:`activate` is called. An activated sensor rejects :meth:`deactivate`
    until its current running segment has lasted at least ``min_duration``
    seconds. Finished running segments are accumulated so that
    :meth:`get_total_running_time` can report the total on-time.

    Args:
        min_duration: Minimum running time (seconds) before a deactivation
            request is honoured.
        boot_time: Delay (seconds) between ``activate`` and the sensor being
            reported as activated.
        status: Initial status, ``'activated'`` or ``'off'``.
        reset_min_duration: If True, :meth:`activate` on an already activated
            sensor restarts the current running segment, so ``min_duration``
            is measured from the latest activation request.
    """

    def __init__(self, min_duration: float, boot_time: float, status: str = 'off', reset_min_duration: bool = False) -> None:
        assert status in ['activated', 'off']
        self._status = status
        self.min_duration = min_duration
        self.boot_time = boot_time
        self.reset_min_duration = reset_min_duration
        # perf_counter() timestamp of the latest activation request
        self._start_time = None
        self._record_start_time = None # used when init status is 'activated', lower priority than `self._start_time`
        # durations (seconds) of running segments that have already ended
        self._running_time_list = []

    def _current_running_time(self, now: float) -> Optional[float]:
        """Return how long the current running segment has lasted at ``now``.

        The segment starts ``boot_time`` after ``self._start_time``. If that is
        unset (the sensor started out activated), it starts at
        ``self._record_start_time``. Returns None when neither is set. The
        value is only meaningful while the sensor is activated.
        """
        if self._start_time is not None:
            return max(0., now - self._start_time - self.boot_time)
        if self._record_start_time is not None:
            return now - self._record_start_time
        return None

    def activate(self):
        """Request the sensor to be switched on.

        An 'off' sensor starts booting. If ``reset_min_duration`` is set and
        the sensor is already activated, its running segment is restarted.
        The time already spent running is kept in the total running time.
        """
        self.update()
        if self._status == 'off':
            self._status = 'activating'
            self._start_time = time.perf_counter()
        elif self.reset_min_duration and self._status == 'activated':
            now = time.perf_counter()
            # Bank the segment that has run so far. Overwriting `_start_time`
            # below would otherwise drop it from get_total_running_time().
            elapsed = self._current_running_time(now)
            if elapsed is not None:
                self._running_time_list.append(elapsed)
            # simulating the sensor is just activated,
            # so that the running time is close to zero
            self._start_time = now - self.boot_time

    def deactivate(self):
        """Request the sensor to be switched off.

        Only an activated sensor is affected. The request is rejected while the
        current running segment is shorter than ``min_duration``.

        Raises:
            AssertionError: if the sensor started out 'activated', was never
                re-activated, and :meth:`start_recording` was not called.
        """
        self.update()
        if self._status != 'activated':
            return
        if self._start_time is None:
            assert self._record_start_time is not None, f'Please run `start_recording` before `deactivate` if you set init status to \'activated\''
        running_time = self._current_running_time(time.perf_counter())
        # if running_time < self.min_duration, reject to deactivate
        if running_time >= self.min_duration:
            self._running_time_list.append(running_time)
            self._status = 'off'

    def wait_for_booting(self):
        """Block (``time.sleep``) until a booting sensor reaches its boot time."""
        if self.is_activating():
            waiting_time = self.get_waiting_time()
            time.sleep(waiting_time)
            self.update()

    def update(self):
        """Advance the status: clear the start time when off, finish booting when due."""
        if self._status == 'off':
            self._start_time = None
        elif self._status == 'activating':
            if time.perf_counter() - self._start_time >= self.boot_time:
                self._status = 'activated'

    def is_activated(self) -> bool:
        """Return True if the sensor is activated (after updating the status)."""
        self.update()
        return self._status == 'activated'

    def is_activating(self) -> bool:
        """Return True if the sensor is still booting (after updating the status)."""
        self.update()
        return self._status == 'activating'

    def get_waiting_time(self) -> float:
        """Return the seconds left until ``boot_time`` has passed since the last activation (0 if none)."""
        if self._start_time is None:
            return 0.

        elapsed_time = time.perf_counter() - self._start_time
        if elapsed_time > self.boot_time:
            return 0.
        else:
            return self.boot_time - elapsed_time

    def get_total_running_time(self) -> float:
        """Return the total running time in seconds, including the ongoing segment."""
        total_running_time = sum(self._running_time_list, 0.0)
        self.update()
        if self._status == 'activated':
            current_active_time = self._current_running_time(time.perf_counter())
            if current_active_time is not None:
                total_running_time += current_active_time
        return total_running_time

    def start_recording(self):
        """Set the reference time used by a sensor that started out 'activated' (first call only)."""
        if self._record_start_time is None:
            self._record_start_time = time.perf_counter()

class SensorStatusGroup:
    """A named collection of :class:`SensorStatus` objects that share one configuration.

    Args:
        names: Sensor names; each gets its own ``SensorStatus``.
        min_duration, boot_time, status, reset_min_duration: Forwarded
            unchanged to every ``SensorStatus``.
    """

    def __init__(self, names: List[str], min_duration: float, boot_time: float, status: str = 'off', reset_min_duration: bool = False) -> None:
        self.names = names
        self._sensor_status_dict = {
            sensor_name: SensorStatus(min_duration, boot_time, status, reset_min_duration)
            for sensor_name in names
        }

    def __getitem__(self, name: str) -> SensorStatus:
        """Return the status object of sensor ``name``."""
        return self._sensor_status_dict[name]

    def update(self):
        """Advance the status of every sensor in the group."""
        for status_obj in self._sensor_status_dict.values():
            status_obj.update()

    def wait_for_booting(self, names: Optional[List[str]] = None):
        """Sleep once for the longest remaining boot time among ``names``.

        Args:
            names: Sensors to wait for. ``None`` means every sensor in the group.
        """
        wanted = set(self.names if names is None else names)
        # every waiting time is >= 0, so the max equals the original running max from 0.
        pending = [
            status_obj.get_waiting_time()
            for sensor_name, status_obj in self._sensor_status_dict.items()
            if sensor_name in wanted and status_obj.is_activating()
        ]
        time.sleep(max(pending, default=0.))

    def start_recording(self):
        """Call ``start_recording`` on every sensor in the group."""
        for status_obj in self._sensor_status_dict.values():
            status_obj.start_recording()

    def get_total_running_time_dict(self) -> Dict[str, float]:
        """Return ``{sensor name: total running time in seconds}``."""
        return {
            sensor_name: status_obj.get_total_running_time()
            for sensor_name, status_obj in self._sensor_status_dict.items()
        }


@DETECTORS.register_module()
class SceneDetectorV3(BaseModule):
    """Scene-adaptive 3D detector that decides per frame which sensors to use.

    A scene classifier first decides whether LiDAR is used.
    - With LiDAR, ``camera_lidar_detector`` (CMT) runs on LiDAR points,
      CAM_FRONT and the cameras picked by its router.
    - Otherwise, ``camera_only_detector`` (StreamPETR) runs on CAM_FRONT and
      the cameras picked by its own router.

    Sensor on/off states are tracked with ``SensorStatus`` /
    ``SensorStatusGroup``. A sensor that is still booting is skipped, unless
    ``wait_for_booting`` is set, in which case inference sleeps until it is
    ready. Only inference is supported (``forward`` asserts
    ``not self.training``).
    """

    # nuScenes camera order used to index the per-camera lists of the input data
    _NUSCENES_CAM_NAMES = ('CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT')
    _NUSCENES_CAM_IDS = {name: i for i, name in enumerate(_NUSCENES_CAM_NAMES)}
    # cameras the routers can switch on/off; CAM_FRONT is always activated
    _ROUTED_CAM_NAMES = _NUSCENES_CAM_NAMES[1:]

    def __init__(self,
                 camera_lidar_detector=None,
                 camera_only_detector=None,
                 scene_classifier=None,
                 separate_camera_stream=True,
                 sensor_default_status='off',
                 camera_min_duration=None,
                 camera_boot_time=None,
                 lidar_min_duration=None,
                 lidar_boot_time=None,
                 reset_min_duration=False,
                 wait_for_booting=False,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 **kwargs):
        """Build the sub-detectors and the sensor status trackers.

        Args:
            camera_lidar_detector, camera_only_detector, scene_classifier:
                Configs passed to ``build_detector``. Each built model must be
                a ``PipelineDetector``.
            separate_camera_stream: Run camera work on a dedicated CUDA stream
                on device 0 instead of the default stream.
            sensor_default_status: Initial status of every sensor,
                ``'off'`` or ``'activated'``.
            camera_min_duration, camera_boot_time, lidar_min_duration,
                lidar_boot_time: Required sensor timings in seconds.
            reset_min_duration: Forwarded to every ``SensorStatus``.
            wait_for_booting: Sleep until booting sensors are ready instead of
                skipping them for the current frame.
            train_cfg, test_cfg, pretrained: Accepted but not used here.
        """
        super(SceneDetectorV3, self).__init__(**kwargs)
        if camera_lidar_detector is not None:
            print('building camera_lidar_detector')
            self.camera_lidar_detector = build_detector(camera_lidar_detector)
            assert isinstance(self.camera_lidar_detector, PipelineDetector)
        if camera_only_detector is not None:
            print('building camera_only_detector')
            self.camera_only_detector = build_detector(camera_only_detector)
            assert isinstance(self.camera_only_detector, PipelineDetector)
        if scene_classifier is not None:
            print('building scene_classifier')
            self.scene_classifier = build_detector(scene_classifier)
            assert isinstance(self.scene_classifier, PipelineDetector)

        assert camera_min_duration is not None
        assert camera_boot_time is not None
        assert lidar_min_duration is not None
        assert lidar_boot_time is not None

        # CAM_FRONT is always activated, so only the routed cameras get a status
        self.camera_sensor_status_group = SensorStatusGroup(list(self._ROUTED_CAM_NAMES), camera_min_duration, camera_boot_time, sensor_default_status, reset_min_duration)
        self.lidar_sensor_status = SensorStatus(lidar_min_duration, lidar_boot_time, sensor_default_status, reset_min_duration)
        self.wait_for_booting = wait_for_booting

        if separate_camera_stream:
            self.stream_camera = torch.cuda.Stream(device=0)
        else:
            self.stream_camera = torch.cuda.default_stream(device=0)

    @classmethod
    def _camera_ids(cls, cameras):
        """Return the sorted nuScenes indices of ``cameras`` (KeyError on unknown names)."""
        return sorted(cls._NUSCENES_CAM_IDS[name] for name in cameras)

    @staticmethod
    def _shape_with_num_cams(shape, num_cams):
        """Return the first three entries of ``shape`` followed by ``num_cams``, as a tuple."""
        return tuple(list(shape[:3]) + [num_cams])

    def collect_camera_lidar_data(self, camera_lidar_data, cameras, use_lidar):
        """Select the raw inputs of ``camera_lidar_detector``'s pipelines for a sensor subset.

        Args:
            camera_lidar_data (dict): Raw frame data. Per-camera lists are
                indexed in nuScenes camera order.
            cameras (list[str]): Camera names to keep.
            use_lidar (bool): Whether to include the LiDAR points.

        Returns:
            dict: A new dict with the shared fields, the per-camera lists
            restricted to ``cameras``, and the LiDAR fields, or a
            ``None``/empty placeholder when ``use_lidar`` is False.
        """
        camera_ids = self._camera_ids(cameras)

        shared_keys = ('sample_idx', 'timestamp', 'box_type_3d', 'box_mode_3d', 'img_norm_cfg', 'img_fields',
                       'bbox3d_fields', 'pts_mask_fields', 'pts_seg_fields', 'bbox_fields', 'mask_fields', 'seg_fields')
        result = {key: camera_lidar_data[key] for key in shared_keys}
        for key in ('img_shape', 'ori_shape', 'pad_shape'):
            result[key] = self._shape_with_num_cams(camera_lidar_data[key], len(camera_ids))

        if camera_ids:
            for key in ('img_filename', 'filename', 'img_timestamp', 'lidar2img', 'lidar2cam', 'cam_intrinsic', 'img'):
                result[key] = [camera_lidar_data[key][i] for i in camera_ids]

        if use_lidar:
            for key in ('pts_filename', 'sweeps', 'points'):
                result[key] = camera_lidar_data[key]
        else:
            result['pts_filename'] = None
            result['sweeps'] = []

        return result

    def collect_camera_only_data(self, camera_only_data, cameras):
        """Select the raw inputs of ``camera_only_detector``'s pipeline for a camera subset.

        Args:
            camera_only_data (dict): Raw frame data. Per-camera lists are
                indexed in nuScenes camera order.
            cameras (list[str]): Camera names to keep.

        Returns:
            dict: A new dict with the shared fields and the per-camera lists
            restricted to ``cameras``.
        """
        camera_ids = self._camera_ids(cameras)

        shared_keys = ('sample_idx', 'prev_idx', 'scene_token', 'frame_idx', 'prev_exists', 'ego_pose', 'ego_pose_inv',
                       'timestamp', 'box_type_3d', 'box_mode_3d', 'img_norm_cfg', 'img_fields', 'bbox3d_fields',
                       'pts_mask_fields', 'pts_seg_fields', 'bbox_fields', 'mask_fields', 'seg_fields')
        result = {key: camera_only_data[key] for key in shared_keys}
        for key in ('img_shape', 'ori_shape', 'pad_shape'):
            result[key] = self._shape_with_num_cams(camera_only_data[key], len(camera_ids))

        if camera_ids:
            for key in ('img_filename', 'filename', 'img_timestamp', 'lidar2img', 'intrinsics', 'extrinsics', 'img'):
                result[key] = [camera_only_data[key][i] for i in camera_ids]
        return result

    def get_img2lidar_query(self, cam_intrinsic, lidar2cam, post_rot=0.5, post_tran=(0., -130.)):
        """Build per-camera ``lidar2img`` matrices for CMT's query construction.

        The first two rows (columns 0-2) of each intrinsic are scaled by
        ``post_rot``, and ``post_tran`` is added to their last column. This
        presumably mirrors the resize/crop of the image pipeline; confirm
        against the config. The result is composed with ``lidar2cam``.

        The inputs are not modified. Each intrinsic is deep-copied before
        editing, because these objects belong to the caller's frame data.

        Args:
            cam_intrinsic: Sequence of intrinsic matrices, indexed like ``lidar2cam``.
            lidar2cam: Sequence of LiDAR-to-camera matrices.
            post_rot (float): Scale applied to the first two intrinsic rows.
            post_tran (sequence of 2 floats): Offset added to the principal point.

        Returns:
            list: ``adjusted_intrinsic @ lidar2cam`` for every camera.
        """
        post_rot = torch.eye(2) * post_rot
        post_tran = torch.tensor(post_tran, dtype=torch.float32)

        lidar2img_query = []
        for i, l2c in enumerate(lidar2cam):
            intrinsic = copy.deepcopy(cam_intrinsic[i])
            intrinsic[:2, :3] = post_rot @ intrinsic[:2, :3]
            intrinsic[:2, 2] = post_tran + intrinsic[:2, 2]
            lidar2img_query.append(intrinsic @ l2c)
        return lidar2img_query

    def start_sensor_recording(self):
        """Start running-time recording for the LiDAR and every routed camera."""
        self.lidar_sensor_status.start_recording()
        self.camera_sensor_status_group.start_recording()

    def get_total_sensor_running_time_dict(self):
        """Return ``{sensor name: total running time in seconds}`` for LIDAR_TOP and the routed cameras."""
        result = dict()
        result['LIDAR_TOP'] = self.lidar_sensor_status.get_total_running_time()
        result.update(self.camera_sensor_status_group.get_total_running_time_dict())
        return result

    def _select_active_cameras(self, selected_cams_mask):
        """Update camera sensor states from a router mask and return the usable cameras.

        Cameras chosen by the router are activated; the others are asked to
        deactivate. If ``self.wait_for_booting`` is set, this sleeps once,
        after every selected camera has been switched on, so their boot times
        overlap instead of adding up. Only cameras that report activated
        afterwards are returned.
        """
        # iterating the mask forces a GPU-CPU synchronization
        selected_cams = [self._ROUTED_CAM_NAMES[i] for i, x in enumerate(selected_cams_mask[0].to(torch.bool)) if x]
        for cam_name in self._ROUTED_CAM_NAMES:
            if cam_name in selected_cams:
                self.camera_sensor_status_group[cam_name].activate()
            else:
                self.camera_sensor_status_group[cam_name].deactivate()
        if self.wait_for_booting:
            self.camera_sensor_status_group.wait_for_booting(selected_cams)
        return [cam for cam in selected_cams if self.camera_sensor_status_group[cam].is_activated()]

    def _extract_lidar_feats(self, camera_lidar_data):
        """Load the LiDAR points of the frame and extract CMT point features."""
        # use default cuda stream for lidar because spconv throws a error when using custom stream
        with torch.cuda.stream(torch.cuda.default_stream(0)):
            pts_data = self.collect_camera_lidar_data(camera_lidar_data, cameras=[], use_lidar=True)
            pts_data = self.camera_lidar_detector.pts_pipeline(pts_data)
            points = [pts_data['points'][0].data.cuda(non_blocking=True)]
            # note: CMT's extract_pts_feat does not use img_feats or img_metas
            return self.camera_lidar_detector.detector.extract_pts_feat(points, None, None)

    def _forward_camera_lidar(self, camera_lidar_data, scene_classifier_token):
        """Camera+LiDAR path (CMT) on LiDAR, CAM_FRONT and router-selected cameras.

        Returns:
            tuple: ``(bbox_list, selected_cams)``.
        """
        detector = self.camera_lidar_detector.detector
        with torch.cuda.stream(self.stream_camera):
            # load and process the CAM_FRONT image for CMT
            front_data = self.collect_camera_lidar_data(camera_lidar_data, cameras=['CAM_FRONT'], use_lidar=False)
            front_data = self.camera_lidar_detector.img_pipeline(front_data)
            img_metas_cam_front = [front_data['img_metas'][0].data]
            img_cam_front = front_data['img'][0].data.cuda(non_blocking=True)

            # extract CAM_FRONT features and run the router
            img_cam_front_feats = detector.extract_img_feat(img_cam_front, img_metas_cam_front)
            selected_cams_mask, _ = detector.router_forward(img_cam_front, img_cam_front_feats)
            selected_cams = self._select_active_cameras(selected_cams_mask)

        if selected_cams:
            with torch.cuda.stream(self.stream_camera):
                # load and process the router-selected images, then extract their features
                other_data = self.collect_camera_lidar_data(camera_lidar_data, cameras=selected_cams, use_lidar=False)
                other_data = self.camera_lidar_detector.img_pipeline(other_data)
                img_metas_cam_other = [other_data['img_metas'][0].data]
                img_cam_other = other_data['img'][0].data.cuda(non_blocking=True)
                img_cam_other_feats = detector.extract_img_feat(img_cam_other, img_metas_cam_other)

                # concat image features, CAM_FRONT first
                img_feats = tuple(
                    torch.concat([front_feat, other_feat], dim=0)
                    for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats))
        else:
            img_feats = img_cam_front_feats

        # issued on the default stream while camera work may still run on self.stream_camera
        pts_feats = self._extract_lidar_feats(camera_lidar_data)

        # construct img_metas
        meta_front = img_metas_cam_front[0]
        img_meta = dict(sample_idx=meta_front['sample_idx'], box_type_3d=meta_front['box_type_3d'])
        for key in ('filename', 'lidar2img', 'pad_shape'):
            if selected_cams:
                img_meta[key] = meta_front[key] + img_metas_cam_other[0][key]
            else:
                img_meta[key] = meta_front[key]
        # note: CMT's transformer uses `lidar2img` of all cameras to build its queries
        img_meta['lidar2img_query'] = self.get_img2lidar_query(camera_lidar_data['cam_intrinsic'], camera_lidar_data['lidar2cam'])
        img_metas = [img_meta]

        # synchronize before transformer
        self.stream_camera.synchronize()

        # transformer and bbox decoder
        bbox_list = [dict() for _ in img_metas]
        bbox_pts = detector.simple_test_pts(pts_feats, img_feats, img_metas, rescale=False)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox

        assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'
        return bbox_list, selected_cams

    def _forward_camera_only(self, camera_only_data, scene_classifier_token):
        """Camera-only path (StreamPETR) on CAM_FRONT and router-selected cameras.

        Returns:
            tuple: ``(bbox_list, selected_cams)``.
        """
        detector = self.camera_only_detector.detector
        # load and process the CAM_FRONT image for StreamPETR
        front_data = self.collect_camera_only_data(camera_only_data, cameras=['CAM_FRONT'])
        front_data = self.camera_only_detector.img_pipeline(front_data)
        img_metas_cam_front = [front_data['img_metas'][0].data]
        img_cam_front = front_data['img'][0].data.unsqueeze(0).cuda(non_blocking=True)

        assert scene_classifier_token == img_metas_cam_front[0]['sample_idx'], f'{scene_classifier_token} != {img_metas_cam_front[0]["sample_idx"]}'

        img_cam_front_feats = detector.extract_img_feat(img_cam_front, 1)
        selected_cams_mask, _ = detector.router_forward(img_cam_front_feats)
        selected_cams = self._select_active_cameras(selected_cams_mask)

        if selected_cams:
            # load and process the router-selected images for StreamPETR
            other_data = self.collect_camera_only_data(camera_only_data, cameras=selected_cams)
            other_data = self.camera_only_detector.img_pipeline(other_data)
            img_metas_cam_other = [other_data['img_metas'][0].data]
            img_cam_other = other_data['img'][0].data.unsqueeze(0).cuda(non_blocking=True)
            img_cam_other_feats = detector.extract_img_feat(img_cam_other, 1)

        # construct data
        data = dict()
        for key in ('ego_pose', 'ego_pose_inv', 'timestamp'):
            data[key] = front_data[key][0].data.unsqueeze(0).cuda(non_blocking=True)
        per_cam_keys = ('lidar2img', 'intrinsics', 'extrinsics', 'img_timestamp')
        if selected_cams:
            data['img'] = torch.concat([img_cam_front, img_cam_other], dim=1)
            for key in per_cam_keys:
                data[key] = torch.concat([front_data[key][0].data, other_data[key][0].data], dim=0).unsqueeze(0).cuda(non_blocking=True)
            data['img_feats'] = [
                torch.concat([front_feat, other_feat], dim=1)
                for front_feat, other_feat in zip(img_cam_front_feats, img_cam_other_feats)]
        else:
            data['img'] = img_cam_front
            for key in per_cam_keys:
                data[key] = front_data[key][0].data.unsqueeze(0).cuda(non_blocking=True)
            data['img_feats'] = img_cam_front_feats

        # construct img_metas
        meta_front = img_metas_cam_front[0]
        img_meta = dict(sample_idx=meta_front['sample_idx'], scene_token=meta_front['scene_token'], box_type_3d=meta_front['box_type_3d'])
        if selected_cams:
            img_meta['pad_shape'] = meta_front['pad_shape'] + img_metas_cam_other[0]['pad_shape']
        else:
            img_meta['pad_shape'] = meta_front['pad_shape']
        img_metas = [img_meta]

        # +1 is for CAM_FRONT
        self.camera_only_detector.set_num_cams_valid(len(selected_cams) + 1)
        bbox_list = [dict() for _ in img_metas]
        bbox_pts = detector.simple_test_pts(img_metas, **data)
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
        return bbox_list, selected_cams

    def forward(self, camera_lidar_data, camera_only_data, scene_classifier_data):
        """Run inference on one frame, choosing the sensors adaptively.

        Args:
            camera_lidar_data (dict): Raw inputs for ``camera_lidar_detector``.
            camera_only_data (dict): Raw inputs for ``camera_only_detector``.
            scene_classifier_data: ``(images, sample token)`` for the scene classifier.

        Returns:
            list[dict]: One dict per sample, with ``'pts_bbox'`` and ``'metas'``
            (sample token, whether LiDAR was used, cameras used).
        """
        assert not self.training

        scene_classifier_imgs = scene_classifier_data[0]
        scene_classifier_token = scene_classifier_data[1]

        scene_classifier_imgs = scene_classifier_imgs.unsqueeze(0).cuda(non_blocking=True)
        scene_classifier_outputs = self.scene_classifier(scene_classifier_imgs)
        lidar_logits = scene_classifier_outputs[0, -1]
        if len(lidar_logits) == 1:
            # single score, thresholded at 0.5
            use_lidar = (lidar_logits >= 0.5).item()
        else:
            # one score per decision: take the most likely
            use_lidar = lidar_logits.argmax(dim=0).item()

        # check lidar: a LiDAR that is not running yet is only used if we wait for it
        if use_lidar:
            if not self.lidar_sensor_status.is_activated():
                self.lidar_sensor_status.activate()
                if self.wait_for_booting:
                    self.lidar_sensor_status.wait_for_booting()
                else:
                    use_lidar = 0
        else:
            self.lidar_sensor_status.deactivate()

        if use_lidar:
            bbox_list, selected_cams = self._forward_camera_lidar(camera_lidar_data, scene_classifier_token)
        else:
            bbox_list, selected_cams = self._forward_camera_only(camera_only_data, scene_classifier_token)

        for result_dict in bbox_list:
            result_dict['metas'] = dict()
            result_dict['metas']['sample_token'] = scene_classifier_token
            result_dict['metas']['use_lidar'] = use_lidar
            result_dict['metas']['used_cameras'] = ['CAM_FRONT'] + selected_cams
            result_dict['metas']['time_data_loading'] = None
            result_dict['metas']['time_data_processing'] = None
            result_dict['metas']['infer_time_dict'] = None

        return bbox_list
