# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
# Copyright (c) 2021 Wang, Yue
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
#  Modified by Shihao Wang
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from projects.mmdet3d_plugin.models.utils.misc import locations
from projects.mmdet3d_plugin.utils import CUDATimer

from .repdetr3d import RepDetr3D


class GumbelRouter(nn.Module):
    """Bank of independent binary gates sampled with Gumbel-Softmax.

    One linear layer emits two logits per gate. ``forward`` draws a
    Gumbel-Softmax sample over every logit pair and returns the component
    at index 1 of each pair.

    Args:
        num_classes: Number of binary gates.
        input_dim: Length of the input feature vector.
        tau: Temperature forwarded to ``F.gumbel_softmax``.
        hard: ``hard`` flag forwarded to ``F.gumbel_softmax``.
        bias_init: Initial bias of the index-1 logit of every gate; the
            index-0 logit bias starts at 0.
    """

    def __init__(self,
                 num_classes: int,
                 input_dim: int,
                 tau: float = 1.0,
                 hard: bool = True,
                 bias_init: float = 0.) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.input_dim = input_dim
        self.tau = tau
        self.hard = hard

        # Output layout is [gate0_idx0, gate0_idx1, gate1_idx0, gate1_idx1, ...].
        self.gate = nn.Linear(input_dim, num_classes * 2)

        # A custom bias on the index-1 logit biases the router towards
        # keeping every modality switched on.
        pair_bias = torch.tensor([0, bias_init], dtype=self.gate.bias.dtype)
        self.gate.bias.data = pair_bias.repeat(num_classes)

    def forward(self, x: torch.Tensor):
        """Sample the gates for ``x`` of shape ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, num_classes)`` holding the index-1
            component of each Gumbel-Softmax sample.
        """
        pair_logits = self.gate(x).reshape(x.shape[0], self.num_classes, 2)
        samples = F.gumbel_softmax(pair_logits, tau=self.tau, hard=self.hard, dim=2)
        return samples[:, :, 1]


@DETECTORS.register_module()
class RepDetr3DWithGumbelModalRouter(RepDetr3D):

    def __init__(self,
                 router_tau=1.0,
                 router_hard=True,
                 router_bias_init=0.,
                 loss_router_aux_weight=0.02,
                 loss_router_aux_per_expert_weight=None,
                 use_mask=False,
                 masked_img_path=None,
                 masked_img_feats_path=None,
                 masked_pts_feats_path=None,
                 router_only=False,
                 **kwargs):
        """Build the detector and attach a Gumbel-Softmax camera router.

        Args:
            router_tau: ``tau`` handed to ``GumbelRouter``.
            router_hard: ``hard`` handed to ``GumbelRouter``.
            router_bias_init: ``bias_init`` handed to ``GumbelRouter``.
            loss_router_aux_weight: Scale applied to the auxiliary router
                loss in ``forward_train``.
            loss_router_aux_per_expert_weight: Optional sequence with one
                weight per routed camera (5 entries). Defaults to all ones.
            use_mask: Not supported; a truthy value raises
                ``NotImplementedError``.
            masked_img_path: Only relevant to the unimplemented ``use_mask``
                path; otherwise unused.
            masked_img_feats_path: Accepted but not used by this constructor.
            masked_pts_feats_path: Accepted but not used by this constructor.
            router_only: If true, freeze every parameter except the router's.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            NotImplementedError: If ``use_mask`` is truthy.
        """
        super(RepDetr3DWithGumbelModalRouter, self).__init__(**kwargs)
        num_classes = 5 # 5 cameras
        self.router = GumbelRouter(
            num_classes=num_classes,
            input_dim=self.img_neck.out_channels,
            tau=router_tau,
            hard=router_hard,
            bias_init=router_bias_init,
        )
        self.loss_router_aux_weight = loss_router_aux_weight

        if loss_router_aux_per_expert_weight is not None:
            assert len(loss_router_aux_per_expert_weight) == num_classes
            per_expert_weight = torch.tensor(loss_router_aux_per_expert_weight)
        else:
            per_expert_weight = torch.ones(num_classes)
        self.loss_router_aux_per_expert_weight = per_expert_weight

        self.use_mask = use_mask
        if self.use_mask:
            # Substituting a stored placeholder image is not implemented.
            raise NotImplementedError()
        self.masked_img = None

        if router_only:
            # Freeze everything first, then re-enable only the router.
            for param in self.parameters():
                param.requires_grad = False
            for param in self.router.parameters():
                param.requires_grad = True

    def mask_feats(self,
                   img_feats,
                   feat_weights):
        B = feat_weights.shape[0]

        # never mask CAM_FRONT
        cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
        img_feat_weights = torch.concat([cam_front_weight, feat_weights], dim=1)

        if not self.use_mask:
            img_feats_new = []
            for img_feat in img_feats:
                img_feat = img_feat * img_feat_weights[:, None, :, None, None, None]
                img_feats_new.append(img_feat)

        else:
            raise NotImplementedError()

        return img_feats_new, img_feat_weights

    def loss_router_aux(self, used_expert: torch.Tensor) -> torch.Tensor:
        """
        loss_router_aux_v2

        支持专家权重的负载均衡损失函数
        
        Args:
            used_expert (torch.Tensor): 形状(batch_size, num_experts)的张量
                                        (取值为0.0或1.0)
            weights (Optional[torch.Tensor]): 形状(num_experts,)的专家权重张量
                                            默认全1，表示等权重
        
        Returns:
            torch.Tensor: 标量损失值
        
        示例:
            >>> used_expert = torch.tensor([[1,1,1,2]], dtype=torch.float)
            >>> weights = torch.tensor([1.0, 1.0, 1.0, 2.0])
            >>> loss = load_balancing_loss(used_expert, weights)
            >>> loss == 0.0  # 当选择次数符合权重比例时损失为0
        """
        num_experts = used_expert.size(1)
        
        expert_counts = torch.sum(used_expert, dim=0)  # (num_experts,)
        total_usage = torch.sum(expert_counts)
        
        # 无专家使用时返回0
        if total_usage == 0:
            return torch.tensor(0.0, device=used_expert.device)
        
        # 应用权重归一化
        weights = self.loss_router_aux_per_expert_weight.to(used_expert.device)
        adjusted_counts = expert_counts.float() / (weights + 1e-6)  # 防止除零
        mean = torch.mean(adjusted_counts)
        variance = torch.var(adjusted_counts, unbiased=False)
        
        # 当所有专家按权重比例使用时，方差应为0
        cv_squared = variance / (mean ** 2 + 1e-6)
        return cv_squared

    def router_forward(self, img_feats):
        B, _, N, C, H, W = img_feats[0].shape
        assert img_feats[0].shape[1] == 1

        router_inputs = img_feats[0][:, 0, 0] # assume img_feats[0].shape[1] == 1, (B, C, H, W)
        router_inputs = F.adaptive_avg_pool2d(router_inputs, output_size=(1, 1)) # (B, C, 1, 1)
        router_inputs = router_inputs[:, :, 0, 0] # (B, C)

        logits = self.router(router_inputs)
        return logits

    def count_sensor(self, feat_weights):
        with torch.no_grad():
            B = feat_weights.shape[0]

            # 1. never mask CAM_FRONT, 2. add LiDAR weights for logging, consistent with CMT
            cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            lidar_weight = torch.zeros(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            feat_weights = torch.concat([cam_front_weight, feat_weights, lidar_weight], dim=1)

            feat_weights = feat_weights.to(torch.float32)
            sensor_cnt = feat_weights.sum(dim=0)
            return sensor_cnt

    def extract_feat_with_mask(self, points, masked_points, img, img_metas):
        """Extract features for real and placeholder inputs in one pass.

        ``self.masked_img`` is appended to ``img`` as one extra batch entry
        and ``masked_points`` is appended to ``points``; the merged inputs go
        through ``self.extract_feat`` and the outputs are split back.

        Returns:
            tuple: ``(img_feats, pts_feats, masked_img_feats,
            masked_pts_feats)``; the image entries are tuples, the point
            entries are lists.
        """
        batch = img.shape[0]
        all_points = points + masked_points
        placeholder = self.masked_img.to(device=img.device, dtype=img.dtype).unsqueeze(0)
        all_imgs = torch.cat([img, placeholder], dim=0)

        merged_img_feats, merged_pts_feats = self.extract_feat(
            all_points, img=all_imgs, img_metas=img_metas)

        # Point features: the first ``batch`` rows are real, the rest placeholder.
        pts_feats, masked_pts_feats = [], []
        for feat in merged_pts_feats:
            pts_feats.append(feat[:batch])
            masked_pts_feats.append(feat[batch:])

        # Image features are flattened over (batch + 1) * N views.
        real_img, placeholder_img = [], []
        for feat in merged_img_feats:
            views_per_sample = feat.shape[0] // (batch + 1)
            boundary = batch * views_per_sample
            real_img.append(feat[:boundary])
            placeholder_img.append(feat[boundary:])

        return tuple(real_img), pts_feats, tuple(placeholder_img), masked_pts_feats

    def forward_train(self,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None,
                      depths=None,
                      centers2d=None,
                      **data):
        """Forward training function with Gumbel camera routing.

        Image features are extracted, the router samples a gate per camera,
        the features are scaled by the gates, and the detection losses are
        computed by ``self.obtain_history_memory``. An auxiliary router loss
        and per-sensor usage ratios are added to the returned dict.

        Args:
            img_metas (list[dict], optional): Meta information of each sample.
            gt_bboxes_3d (optional): Ground truth 3D boxes.
            gt_labels_3d (optional): Ground truth labels of 3D boxes.
            gt_labels (optional): Ground truth labels of 2D boxes in images.
            gt_bboxes (optional): Ground truth 2D boxes in images.
            gt_bboxes_ignore (optional): 2D boxes to be ignored.
            depths (optional): Forwarded to ``obtain_history_memory``.
            centers2d (optional): Forwarded to ``obtain_history_memory``.
            **data: Must contain ``'img'``; its dim 1 is treated as the frame
                axis. ``'img_feats'`` and ``'valid_imgs'`` are added before
                the dict is forwarded to ``obtain_history_memory``.

        Returns:
            dict: Detection losses plus ``'loss_router_aux'`` and one
            ``'sensor.{i}.ratio'`` entry per sensor slot.

        Raises:
            NotImplementedError: If the input holds more frames than
                ``self.num_frame_backbone_grads``.
        """

        if self.test_flag: #for interval evaluation
            self.pts_bbox_head.reset_memory()
            self.test_flag = False

        T = data['img'].size(1)

        rec_img = data['img'][:, -self.num_frame_backbone_grads:]
        rec_img_feats = self.extract_feat(rec_img, self.num_frame_backbone_grads)

        logits = self.router_forward(rec_img_feats)
        rec_img_feats, img_feat_weights = self.mask_feats(rec_img_feats, logits)

        if T-self.num_frame_backbone_grads > 0:
            # Extra history frames without backbone gradients are not
            # supported together with routing.
            raise NotImplementedError()

        data['img_feats'] = rec_img_feats
        data['valid_imgs'] = img_feat_weights.to(torch.bool)

        losses = self.obtain_history_memory(gt_bboxes_3d,
                        gt_labels_3d, gt_bboxes,
                        gt_labels, img_metas, centers2d, depths, gt_bboxes_ignore, **data)

        loss_router_aux = self.loss_router_aux(logits) * self.loss_router_aux_weight
        losses.update({'loss_router_aux': loss_router_aux})

        with torch.no_grad():
            sensor_cnt = self.count_sensor(logits)
            # Only aggregate across ranks when a process group exists, so
            # non-distributed training does not crash on the collectives.
            world_size = 1
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(sensor_cnt)
                world_size = dist.get_world_size()
            global_batch_size = rec_img_feats[0].shape[0] * world_size
            sensor_ratio = sensor_cnt / global_batch_size
            for i, ratio in enumerate(sensor_ratio):
                losses[f'sensor.{i}.ratio'] = ratio

        return losses


    def forward_test(self,
                     points=None,
                     masked_points=None,
                     img_metas=None,
                     img=None, **kwargs):
        """Validate the test-time inputs and dispatch to ``simple_test``.

        Args:
            points (list[torch.Tensor]): the outer list indicates test-time
                augmentations and inner torch.Tensor should have a shape NxC,
                which contains all points in the batch. ``None`` is replaced
                by ``[None]``.
            masked_points (list): indexed with ``[0]`` below; no ``None``
                fallback is applied for it.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
            img (list[torch.Tensor], optional): the outer
                list indicates test-time augmentations and inner
                torch.Tensor should have a shape NxCxHxW, which contains
                all images in the batch. Defaults to None, which is replaced
                by ``[None]``.

        Raises:
            TypeError: If ``points``, ``img`` or ``img_metas`` is not a list.
        """
        if points is None:
            points = [None]
        if img is None:
            img = [None]
        for var, name in [(points, 'points'), (img, 'img'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        # NOTE(review): `simple_test` in this class is declared as
        # `simple_test(self, img_metas, **data)` and reads `data['img']`, so
        # this call with four positional arguments does not match that
        # signature. `masked_points` also defaults to None but is subscripted
        # here. Confirm the intended test entry point before relying on this.
        return self.simple_test(points[0], masked_points[0], img_metas[0], img[0], **kwargs)

    def simple_test(self, img_metas, **data):
        """Test function without augmentation.

        Extracts image features from ``data['img']``, applies the router
        gates to them and runs ``self.simple_test_pts``.

        Returns:
            list[dict]: One dict per entry of ``img_metas``; the prediction
            is stored under ``'pts_bbox'``.
        """
        feats = self.extract_img_feat(data['img'], 1)
        gates = self.router_forward(feats)
        feats, cam_weights = self.mask_feats(feats, gates)

        data['img_feats'] = feats
        data['valid_imgs'] = cam_weights.to(torch.bool)

        results = [{} for _ in range(len(img_metas))]
        detections = self.simple_test_pts(img_metas, **data)
        for slot, detection in zip(results, detections):
            slot['pts_bbox'] = detection
        return results

    def forward_test_sensor_ratio(self, img, img_metas):
        """Report how often each sensor slot is selected for a batch.

        Args:
            img: Input images; ``img.shape[0]`` is used as the batch size.
            img_metas: Passed as second argument to ``extract_img_feat``.

        Returns:
            tuple: ``(sensor_ratio, sensor_cnt)`` where ``sensor_cnt`` comes
            from ``count_sensor`` and ``sensor_ratio`` is the count divided
            by the batch size.
        """
        # NOTE(review): `simple_test` passes the integer 1 as the second
        # argument of `extract_img_feat`; confirm that `img_metas` is what
        # the extractor expects here.
        img_feats = self.extract_img_feat(img, img_metas)
        # `router_forward` of this class takes only the features and returns
        # a single tensor of gate samples.
        gate_samples = self.router_forward(img_feats)

        sensor_cnt = self.count_sensor(gate_samples)
        batch_size = img.shape[0]
        sensor_ratio = sensor_cnt / batch_size
        return sensor_ratio, sensor_cnt

    def simple_test_time(self, points, img_metas, img=None, rescale=False):
        """Time the individual inference stages for a single sample.

        Each stage runs inside a ``CUDATimer`` and its timings are merged
        into one dict via ``self._merge_dict``.

        NOTE(review): this method reads ``self.masked_img_feats``,
        ``self.masked_pts_feats``, ``self.top_p_threshold``,
        ``self.force_no_lidar`` and ``self.force_empty_image``; none of them
        is assigned in this class's ``__init__``, so they are presumably
        expected from elsewhere -- confirm. ``top_p_sampling`` is also not
        among the imports at the top of this file.

        Args:
            points: Point clouds; ``points[0]`` supplies dtype and device when
                LiDAR input is forced to be empty.
            img_metas: Forwarded to the feature extractors and the head.
            img: Image tensor with batch size 1 (asserted).
            rescale: Forwarded to ``simple_test_pts_time``.

        Returns:
            dict: ``{'time_dict': ..., 'sensor_mask': list of 7 bools}``.

        Raises:
            NotImplementedError: If there are no features to decode or
                ``self.with_pts_bbox`` is false.
        """
        assert img.shape[0] == 1
        assert self.masked_img_feats is not None
        assert self.masked_pts_feats is not None

        # Move the stored placeholder features to the GPU and keep the GPU
        # copies on the instance.
        masked_img_feats = ()
        masked_pts_feats = []
        for i in range(len(self.masked_img_feats)):
            masked_img_feats += (self.masked_img_feats[i].cuda(),)
        for i in range(len(self.masked_pts_feats)):
            masked_pts_feats.append(self.masked_pts_feats[i].cuda())
        self.masked_img_feats = masked_img_feats
        self.masked_pts_feats = masked_pts_feats

        time_dict = dict()

        # Optionally replace the point cloud by a single all-zero point.
        if self.force_no_lidar:
            with CUDATimer('mask_lidar_force') as t:
                dtype = points[0].dtype
                device = points[0].device
                points = [torch.zeros(1, 5, dtype=dtype, device=device)]
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        # Optionally replace all images by zeros.
        if self.force_empty_image:
            with CUDATimer('mask_img_force') as t:
                img = torch.zeros_like(img)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        # The first camera is always processed; its features feed the router.
        with CUDATimer('extract_feat_cam_front') as t:
            img_cam_front = img[:, 0:1]
            img_cam_front_feat = self.extract_img_feat(img_cam_front, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('router') as t:
            router_inputs = F.adaptive_avg_pool2d(img_cam_front_feat[0], output_size=(1, 1)) # (B, C, 1, 1)
            router_inputs = router_inputs[:, :, 0, 0] # (B, C)
            # NOTE(review): in this class `self.router` is a GumbelRouter,
            # whose forward returns sampled gates rather than raw logits, so
            # the softmax below is applied to samples -- confirm intent.
            logits = self.router(router_inputs)
            weights = F.softmax(logits, dim=1)
            mask = top_p_sampling(weights, self.top_p_threshold, binary=True)[0].to(torch.bool)
            camera_mask = torch.zeros(6, dtype=torch.bool, device=mask.device)
            # NOTE(review): the router is built with 5 outputs, so `mask[:-1]`
            # would have 4 entries while `camera_mask[1:]` has 5, and the last
            # router output is read as the LiDAR switch -- verify sizes.
            camera_mask[1:] = mask[:-1] # camera_mask[0] is always False
            use_lidar = mask[-1].item()
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        # Only the cameras selected by the router go through the backbone.
        with CUDATimer('extract_feat_cam_other') as t:
            img_cam_other = img[:, camera_mask]
            img_cam_other_feat = self.extract_img_feat(img_cam_other, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        # Reassemble per-level features in camera order, substituting the
        # stored placeholder features for cameras that were skipped.
        with CUDATimer('construct_cam_feat') as t:
            img_feats = ()
            for i in range(len(img_cam_front_feat)):
                img_feat = [img_cam_front_feat[i][0]]
                cam_other_idx = 0
                for use_cam in camera_mask[1:]:
                    if use_cam:
                        img_feat.append(img_cam_other_feat[i][cam_other_idx])
                        cam_other_idx += 1
                    else:
                        img_feat.append(masked_img_feats[i][0])
                img_feat = torch.stack(img_feat, dim=0)
                img_feats += (img_feat,)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        if use_lidar:
            with CUDATimer('extract_lidar_feat') as t:
                pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())
        else:
            # LiDAR skipped: use the stored placeholder features, zero cost.
            pts_feats = masked_pts_feats
            time_dict['extract_lidar_feat'] = 0.0

        if pts_feats is None:
            pts_feats = [None]
        if img_feats is None:
            img_feats = [None]

        if (pts_feats or img_feats) and self.with_pts_bbox:
            time_dict_pts = self.simple_test_pts_time(
                pts_feats, img_feats, img_metas, rescale=rescale)
            time_dict = self._merge_dict(time_dict, time_dict_pts)
        else:
            raise NotImplementedError()

        # Slot 0 (first camera) is always on; the rest mirror the router mask.
        sensor_mask = torch.ones(7, dtype=torch.bool, device=mask.device)
        sensor_mask[1:] = mask
        sensor_mask = sensor_mask.tolist()

        return {'time_dict': time_dict, 'sensor_mask': sensor_mask}


@DETECTORS.register_module()
class RepDetr3DWithTopPModalRouter(RepDetr3D):

    def __init__(self,
                 top_p_threshold=0.9,
                 loss_router_dynamic_weight=1e-4,
                 loss_router_balance_weight=1e-2,
                 use_mask=False,
                 masked_img_path=None,
                 masked_img_feats_path=None,
                 masked_pts_feats_path=None,
                 router_only=False,
                 version='v1',
                 **kwargs):
        """Build the detector and attach a top-p camera router.

        Args:
            top_p_threshold: Threshold handed to ``top_p_sampling``.
            loss_router_dynamic_weight: Scale of the entropy ("dynamic") loss.
            loss_router_balance_weight: Scale of the load-balancing loss.
            use_mask: Not supported; a truthy value raises
                ``NotImplementedError``.
            masked_img_path: Only relevant to the unimplemented ``use_mask``
                path; otherwise unused.
            masked_img_feats_path: Optional file loaded with ``torch.load``
                into ``self.masked_img_feats``.
            masked_pts_feats_path: Optional file loaded with ``torch.load``
                into ``self.masked_pts_feats``.
            router_only: If true, freeze everything except the router and
                the version-specific router modules.
            version: Router variant; only ``'v1'`` passes the assertions.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            NotImplementedError: If ``use_mask`` is truthy.
        """
        super(RepDetr3DWithTopPModalRouter, self).__init__(**kwargs)
        num_classes = 5 # 5 cameras
        self.router = nn.Linear(self.img_neck.out_channels, num_classes, bias=False)
        self.top_p_threshold = top_p_threshold
        self.loss_router_dynamic_weight = loss_router_dynamic_weight
        self.loss_router_balance_weight = loss_router_balance_weight

        self.use_mask = use_mask
        if self.use_mask:
            # Substituting a stored placeholder image is not implemented.
            raise NotImplementedError()
        self.masked_img = None

        # Optional pre-computed placeholder features, kept on the CPU.
        self.masked_img_feats = (
            None if masked_img_feats_path is None
            else torch.load(masked_img_feats_path, map_location='cpu'))
        self.masked_pts_feats = (
            None if masked_pts_feats_path is None
            else torch.load(masked_pts_feats_path, map_location='cpu'))

        self.version = version
        assert self.version in ['v1', 'v2', 'v3']
        assert self.version == 'v1', 'only v1 is supported'

        uses_img_conv = self.version in ('v2', 'v3')
        uses_pixel_conv = self.version == 'v3'
        if uses_img_conv:
            self.masked_img_token = nn.Parameter(torch.zeros(self.img_neck.out_channels))
            nn.init.normal_(self.masked_img_token, std=.02)
            self.router_img_conv = nn.Conv2d(6, 1, kernel_size=1, bias=False) # 6 is the number of images
        if uses_pixel_conv:
            self.router_pixel_conv = nn.Conv2d(
                self.img_neck.out_channels, self.img_neck.out_channels,
                kernel_size=1, bias=False)

        if router_only:
            # Freeze everything, then re-enable the routing components.
            for param in self.parameters():
                param.requires_grad = False
            trainable_modules = [self.router]
            if uses_img_conv:
                self.masked_img_token.requires_grad = True
                trainable_modules.append(self.router_img_conv)
            if uses_pixel_conv:
                trainable_modules.append(self.router_pixel_conv)
            for module in trainable_modules:
                for param in module.parameters():
                    param.requires_grad = True

    def mask_feats(self,
                   img_feats,
                   feat_weights):
        B = feat_weights.shape[0]

        # never mask CAM_FRONT
        cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
        img_feat_weights = torch.concat([cam_front_weight, feat_weights], dim=1)

        if not self.use_mask:
            img_feats_new = []
            for img_feat in img_feats:
                if self.training:
                    img_feat = img_feat * img_feat_weights[:, None, :, None, None, None]
                else:
                    img_feat = img_feat * img_feat_weights[:, :, None, None, None]
                img_feats_new.append(img_feat)

        else:
            raise NotImplementedError()

        return img_feats_new, img_feat_weights

    def router_forward(self, img_feats):
        if self.version == 'v1':
            return self.router_forward_v1(img_feats)
        else:
            raise NotImplementedError(f'Unsupported router version: {self.version}')

    def router_forward_v1(self, img_feats):
        """Route from the globally pooled features of the first camera.

        ``img_feats[0]`` is either 6-D, unpacked as ``(B, T, N, C, H, W)``,
        or 5-D, unpacked as ``(T, N, C, H, W)`` and treated as batch size 1.
        ``T`` must be 1.

        Returns:
            tuple: ``(weights_sampled, weights)`` where ``weights`` is the
            softmax over the router logits and ``weights_sampled`` is the
            result of ``top_p_sampling`` on it.
        """
        first_level = img_feats[0]
        if first_level.dim() == 6:
            _batch, num_frames, _n, _c, _h, _w = first_level.shape
            front_feat = first_level[:, 0, 0]  # (B, C, H, W)
        else:
            # Unpacking enforces a 5-D tensor here.
            num_frames, _n, _c, _h, _w = first_level.shape
            front_feat = first_level[0, 0].unsqueeze(0)  # (1, C, H, W)
        assert num_frames == 1

        pooled = F.adaptive_avg_pool2d(front_feat, output_size=(1, 1))[:, :, 0, 0]  # (B, C)
        pooled = pooled.to(self.router.weight.dtype)

        weights = F.softmax(self.router(pooled), dim=1)
        weights_sampled = top_p_sampling(weights, self.top_p_threshold, binary=True, at_least_one=True)
        return weights_sampled, weights

    def router_forward_v2(self, img, img_feats):
        """Route from all camera features after random camera masking.

        ``img`` only supplies ``B`` and ``N`` (4-D input is treated as
        ``B == 1``) and the device. ``img_feats[0]`` is unpacked as
        ``(_, C, H, W)`` and reshaped to ``(B, N, C, H, W)``.

        Returns:
            tuple: ``(weights_sampled, weights)`` as in ``router_forward_v1``.
        """
        if img.dim() == 4:
            B = 1
            N, _, _, _ = img.shape
        else:
            B, N, _, _, _ = img.shape

        _, C, H, W = img_feats[0].shape
        feats = img_feats[0].reshape(B, N, C, H, W).permute(0, 2, 1, 3, 4)  # (B, C, N, H, W)

        # random masking
        # FIXME: do we need this in testing?
        keep = (torch.rand(B, N, device=img.device) >= 0.5).float()
        keep[:, 0] = 1.0 # never mask CAM_FRONT
        keep = keep[:, None, :, None, None].expand(B, C, N, H, W)
        token = self.masked_img_token[None, :, None, None, None].expand(B, C, N, H, W)
        feats = feats * keep + (1.0 - keep) * token

        # Collapse the camera axis with a 1x1 conv applied per channel.
        fused = self.router_img_conv(feats.reshape(B * C, N, H, W)).squeeze(1)  # (B * C, H, W)
        fused = fused.reshape(B, C, H, W)

        pooled = F.adaptive_avg_pool2d(fused, output_size=(1, 1))[:, :, 0, 0]  # (B, C)

        weights = F.softmax(self.router(pooled), dim=1)
        weights_sampled = top_p_sampling(weights, self.top_p_threshold, binary=True, at_least_one=True)
        return weights_sampled, weights

    def router_forward_v3(self, img, img_feats):
        """Route from all camera features with an extra per-pixel projection.

        Same pipeline as ``router_forward_v2`` but ``img_feats[0]`` first
        goes through ``self.router_pixel_conv``.

        Args:
            img: Only supplies ``B`` and ``N`` (4-D input is treated as
                ``B == 1``) and the device for the random mask.
            img_feats: ``img_feats[0]`` is unpacked as ``(_, C, H, W)`` and
                reshaped to ``(B, N, C, H, W)``.

        Returns:
            tuple: ``(weights_sampled, weights)`` where ``weights`` is the
            softmax over the router logits and ``weights_sampled`` is the
            result of ``top_p_sampling`` on it.
        """
        if img.dim() == 4:
            B = 1
            N, _, _, _ = img.shape
        else:
            B, N, _, _, _ = img.shape

        _, C, H, W = img_feats[0].shape
        router_inputs = self.router_pixel_conv(img_feats[0])
        # Reshape the projected features; reshaping `img_feats[0]` again
        # would silently discard the output of `router_pixel_conv`.
        router_inputs = router_inputs.reshape(B, N, C, H, W)
        router_inputs = router_inputs.permute(0, 2, 1, 3, 4) # (B, C, N, H, W)

        # random masking
        # FIXME: do we need this in testing?
        mask = (torch.rand(B, N, device=img.device) >= 0.5).float()
        mask[:, 0] = 1.0 # never mask CAM_FRONT
        mask = mask[:, None, :, None, None].expand(B, C, N, H, W)
        masked_img_token = self.masked_img_token[None, :, None, None, None].expand(B, C, N, H, W)
        router_inputs = router_inputs * mask + (1.0 - mask) * masked_img_token

        # Collapse the camera axis with a 1x1 conv applied per channel.
        router_inputs = router_inputs.reshape(B * C, N, H, W)
        router_inputs = self.router_img_conv(router_inputs).squeeze(1) # (B * C, H, W)
        router_inputs = router_inputs.reshape(B, C, H, W)

        router_inputs = F.adaptive_avg_pool2d(router_inputs, output_size=(1, 1)) # (B, C, 1, 1)
        router_inputs = router_inputs[:, :, 0, 0] # (B, C)

        logits = self.router(router_inputs)
        weights = F.softmax(logits, dim=1)
        weights_sampled = top_p_sampling(weights, self.top_p_threshold, binary=True, at_least_one=True)

        return weights_sampled, weights

    def dynamic_loss(self, weights):
        return -(weights * torch.log(weights)).sum(dim=1).mean()

    def load_balancing_loss(self, weights: torch.Tensor, weights_sampled: torch.Tensor):
        assert (weights_sampled == 0.).sum() + (weights_sampled == 1.).sum() == weights_sampled.numel()

        _, num_experts = weights.shape
        expert_mask = weights_sampled.int()
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(weights, dim=0)
        loss = torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts
        return loss

    def count_sensor(self, feat_weights):
        with torch.no_grad():
            B = feat_weights.shape[0]

            # 1. never mask CAM_FRONT, 2. add LiDAR weights for logging, consistent with CMT
            cam_front_weight = torch.ones(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            lidar_weight = torch.zeros(B, 1, dtype=feat_weights.dtype, device=feat_weights.device)
            feat_weights = torch.concat([cam_front_weight, feat_weights, lidar_weight], dim=1)

            feat_weights = feat_weights.to(torch.float32)
            sensor_cnt = feat_weights.sum(dim=0)
            return sensor_cnt

    def extract_feat_with_mask(self, points, masked_points, img, img_metas):
        """Extract features for real and placeholder inputs in one pass.

        ``self.masked_img`` is appended to ``img`` as one extra batch entry
        and ``masked_points`` is appended to ``points``; the merged inputs go
        through ``self.extract_feat`` and the outputs are split back.

        Returns:
            tuple: ``(img_feats, pts_feats, masked_img_feats,
            masked_pts_feats)``; the image entries are tuples, the point
            entries are lists.
        """
        num_real = img.shape[0]
        joint_points = points + masked_points
        extra_img = self.masked_img.to(device=img.device, dtype=img.dtype).unsqueeze(0)
        joint_img = torch.cat([img, extra_img], dim=0)

        joint_img_feats, joint_pts_feats = self.extract_feat(
            joint_points, img=joint_img, img_metas=img_metas)

        # Point features: the first ``num_real`` rows belong to real samples.
        pts_feats, masked_pts_feats = [], []
        for level in joint_pts_feats:
            pts_feats.append(level[:num_real])
            masked_pts_feats.append(level[num_real:])

        # Image features are flattened over (num_real + 1) * N views.
        img_feats, masked_img_feats = (), ()
        for level in joint_img_feats:
            cut = num_real * (level.shape[0] // (num_real + 1))
            img_feats = img_feats + (level[:cut],)
            masked_img_feats = masked_img_feats + (level[cut:],)

        return img_feats, pts_feats, masked_img_feats, masked_pts_feats

    def forward_train(self,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None,
                      depths=None,
                      centers2d=None,
                      **data):
        """Forward training function with top-p camera routing.

        Image features are extracted, the router selects cameras, the
        features are scaled by the selection, and the detection losses are
        computed by ``self.obtain_history_memory``. Two router losses and
        per-sensor usage ratios are added to the returned dict.

        Args:
            img_metas (list[dict], optional): Meta information of each sample.
            gt_bboxes_3d (optional): Ground truth 3D boxes.
            gt_labels_3d (optional): Ground truth labels of 3D boxes.
            gt_labels (optional): Ground truth labels of 2D boxes in images.
            gt_bboxes (optional): Ground truth 2D boxes in images.
            gt_bboxes_ignore (optional): 2D boxes to be ignored.
            depths (optional): Forwarded to ``obtain_history_memory``.
            centers2d (optional): Forwarded to ``obtain_history_memory``.
            **data: Must contain ``'img'``; its dim 1 is treated as the frame
                axis. ``'img_feats'`` and ``'valid_imgs'`` are added before
                the dict is forwarded to ``obtain_history_memory``.

        Returns:
            dict: Detection losses plus ``'loss_router_dynamic'``,
            ``'loss_router_balance'`` and one ``'sensor.{i}.ratio'`` entry
            per sensor slot.

        Raises:
            NotImplementedError: If the input holds more frames than
                ``self.num_frame_backbone_grads``.
        """

        if self.test_flag: #for interval evaluation
            self.pts_bbox_head.reset_memory()
            self.test_flag = False

        T = data['img'].size(1)

        rec_img = data['img'][:, -self.num_frame_backbone_grads:]
        rec_img_feats = self.extract_feat(rec_img, self.num_frame_backbone_grads)

        weights_sampled, weights = self.router_forward(rec_img_feats)
        rec_img_feats, img_feat_weights = self.mask_feats(rec_img_feats, weights_sampled)

        if T-self.num_frame_backbone_grads > 0:
            # Extra history frames without backbone gradients are not
            # supported together with routing.
            raise NotImplementedError()

        data['img_feats'] = rec_img_feats
        data['valid_imgs'] = img_feat_weights.to(torch.bool)

        losses = self.obtain_history_memory(gt_bboxes_3d,
                        gt_labels_3d, gt_bboxes,
                        gt_labels, img_metas, centers2d, depths, gt_bboxes_ignore, **data)

        loss_router_dynamic = self.dynamic_loss(weights) * self.loss_router_dynamic_weight
        loss_router_balance = self.load_balancing_loss(weights, weights_sampled) * self.loss_router_balance_weight
        losses.update({'loss_router_dynamic': loss_router_dynamic, 'loss_router_balance': loss_router_balance})

        with torch.no_grad():
            sensor_cnt = self.count_sensor(weights_sampled)
            # Only aggregate across ranks when a process group exists, so
            # non-distributed training does not crash on the collectives.
            world_size = 1
            if dist.is_available() and dist.is_initialized():
                dist.all_reduce(sensor_cnt)
                world_size = dist.get_world_size()
            global_batch_size = rec_img_feats[0].shape[0] * world_size
            sensor_ratio = sensor_cnt / global_batch_size
            for i, ratio in enumerate(sensor_ratio):
                losses[f'sensor.{i}.ratio'] = ratio

        return losses

    def simple_test(self, img_metas, **data):
        """Test function without augmentation.

        Extracts image features from ``data['img']``, masks them with the
        router's top-p camera selection and runs ``self.simple_test_pts``.

        Returns:
            list[dict]: One dict per entry of ``img_metas``; the prediction
            is stored under ``'pts_bbox'``.
        """
        feats = self.extract_img_feat(data['img'], 1)
        selection, _probs = self.router_forward(feats)
        feats, cam_weights = self.mask_feats(feats, selection)

        data['img_feats'] = feats
        data['valid_imgs'] = cam_weights.to(torch.bool)

        results = [{} for _ in range(len(img_metas))]
        detections = self.simple_test_pts(img_metas, **data)
        for slot, detection in zip(results, detections):
            slot['pts_bbox'] = detection
        return results

    def forward_test_sensor_ratio(self, img, img_metas):
        """Report how often each sensor is selected by the router for a batch.

        Args:
            img: Input images; ``img.shape[0]`` is used as the batch size.
            img_metas: Forwarded unchanged to ``self.extract_img_feat``.

        Returns:
            tuple: ``(sensor_ratio, sensor_cnt)`` where ``sensor_cnt`` is the
            result of ``self.count_sensor`` on the sampled router weights and
            ``sensor_ratio`` is that count divided by the batch size.
        """
        # NOTE(review): the other visible call sites invoke
        # ``self.router_forward`` with the features only, whereas this one
        # also passes ``img`` -- confirm against the signature.
        feats = self.extract_img_feat(img, img_metas)
        sampled_weights, _ = self.router_forward(img, feats)

        counts = self.count_sensor(sampled_weights)
        return counts / img.shape[0], counts

    def simple_test_time(self, points, img_metas, img=None, rescale=False):
        """Run one routed inference pass and collect per-stage timings.

        The image at index 0 along dim 1 of ``img`` is always processed. Its
        pooled feature drives the router, which decides which of the remaining
        images and whether the lidar branch are processed. Skipped inputs are
        replaced by the cached ``self.masked_img_feats`` /
        ``self.masked_pts_feats``.

        Args:
            points: Indexable point-cloud input. ``points[0]`` supplies the
                dtype/device when lidar is force-disabled; otherwise it is
                passed to ``self.extract_pts_feat``.
            img_metas: Forwarded unchanged to the feature extractors and head.
            img: Image tensor with ``img.shape[0] == 1``; dim 1 is indexed per
                camera. Assumed to hold 6 cameras, matching the 6-entry
                ``camera_mask`` below -- TODO confirm. Despite the ``None``
                default it is required.
            rescale: Forwarded to ``self.simple_test_pts_time``.

        Returns:
            dict: ``'time_dict'`` with the merged timing entries and
            ``'sensor_mask'``, a list of 7 booleans (index 0 always ``True``,
            indices 1-6 copied from the router mask). No detections are
            returned.

        Raises:
            NotImplementedError: If ``self.with_pts_bbox`` is falsy or both
                feature collections are falsy.
        """
        # ``img`` is dereferenced immediately, so ``img=None`` fails here with
        # an AttributeError rather than an AssertionError.
        assert img.shape[0] == 1
        assert self.masked_img_feats is not None
        assert self.masked_pts_feats is not None

        # Move the cached placeholder features to CUDA and store them back on
        # ``self``.
        masked_img_feats = ()
        masked_pts_feats = []
        for i in range(len(self.masked_img_feats)):
            masked_img_feats += (self.masked_img_feats[i].cuda(),)
        for i in range(len(self.masked_pts_feats)):
            masked_pts_feats.append(self.masked_pts_feats[i].cuda())
        self.masked_img_feats = masked_img_feats
        self.masked_pts_feats = masked_pts_feats

        # NOTE(review): CUDATimer is a project helper; presumably it times the
        # enclosed block and exposes it via get_time_dict() under the given
        # name -- confirm.
        time_dict = dict()

        if self.force_no_lidar:
            with CUDATimer('mask_lidar_force') as t:
                # Replace the point cloud with a single all-zero row of width 5.
                dtype = points[0].dtype
                device = points[0].device
                points = [torch.zeros(1, 5, dtype=dtype, device=device)]
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        if self.force_empty_image:
            with CUDATimer('mask_img_force') as t:
                img = torch.zeros_like(img)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())

        # The image at index 0 is always processed; its feature feeds the router.
        with CUDATimer('extract_feat_cam_front') as t:
            img_cam_front = img[:, 0:1]
            img_cam_front_feat = self.extract_img_feat(img_cam_front, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('router') as t:
            router_inputs = F.adaptive_avg_pool2d(img_cam_front_feat[0], output_size=(1, 1)) # (B, C, 1, 1)
            router_inputs = router_inputs[:, :, 0, 0] # (B, C)
            logits = self.router(router_inputs)
            weights = F.softmax(logits, dim=1)
            # Binary keep-mask for the first (only) sample. Both mask
            # assignments below need 6 entries: all but the last select the
            # cameras at indices 1-5, the last selects lidar.
            mask = top_p_sampling(weights, self.top_p_threshold, binary=True)[0].to(torch.bool)
            camera_mask = torch.zeros(6, dtype=torch.bool, device=mask.device)
            camera_mask[1:] = mask[:-1] # camera_mask[0] is always False
            use_lidar = mask[-1].item()
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('extract_feat_cam_other') as t:
            # Only the router-selected cameras go through the backbone.
            # NOTE(review): if no camera is selected this passes an empty
            # selection to extract_img_feat -- confirm that is supported.
            img_cam_other = img[:, camera_mask]
            img_cam_other_feat = self.extract_img_feat(img_cam_other, img_metas)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        with CUDATimer('construct_cam_feat') as t:
            # Rebuild one stacked tensor per feature level in camera order:
            # the index-0 feature first, then for each remaining camera either
            # its extracted feature or the cached masked placeholder.
            img_feats = ()
            for i in range(len(img_cam_front_feat)):
                img_feat = [img_cam_front_feat[i][0]]
                cam_other_idx = 0
                for use_cam in camera_mask[1:]:
                    if use_cam:
                        img_feat.append(img_cam_other_feat[i][cam_other_idx])
                        cam_other_idx += 1
                    else:
                        img_feat.append(masked_img_feats[i][0])
                img_feat = torch.stack(img_feat, dim=0)
                img_feats += (img_feat,)
        time_dict = self._merge_dict(time_dict, t.get_time_dict())

        if use_lidar:
            with CUDATimer('extract_lidar_feat') as t:
                pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
            time_dict = self._merge_dict(time_dict, t.get_time_dict())
        else:
            # Lidar skipped: use the cached placeholder and record zero time.
            pts_feats = masked_pts_feats
            time_dict['extract_lidar_feat'] = 0.0

        if pts_feats is None:
            pts_feats = [None]
        # ``img_feats`` is always a tuple at this point, so this branch cannot
        # trigger.
        if img_feats is None:
            img_feats = [None]

        if (pts_feats or img_feats) and self.with_pts_bbox:
            time_dict_pts = self.simple_test_pts_time(
                pts_feats, img_feats, img_metas, rescale=rescale)
            time_dict = self._merge_dict(time_dict, time_dict_pts)
        else:
            raise NotImplementedError()

        # Index 0 is always True (the always-processed image); the rest is the
        # router mask (cameras at indices 1-5, then lidar).
        sensor_mask = torch.ones(7, dtype=torch.bool, device=mask.device)
        sensor_mask[1:] = mask
        sensor_mask = sensor_mask.tolist()

        return {'time_dict': time_dict, 'sensor_mask': sensor_mask}


def top_p_sampling(logits, top_p=0.9, temperature=1.0, binary=False, at_least_one=False):
    """
    Zero out the low-ranked tail of every distribution (top-p / nucleus filter).

    Entries along the last dimension are ranked in descending order. An entry is
    kept only while the running sum of ``logits / temperature``, up to and
    including that entry, does not exceed ``top_p``. Nothing is sampled and no
    softmax is applied here, so despite the parameter name the caller is expected
    to pass values that already behave like probabilities.

    :param logits: Tensor of shape (..., L); filtering is done independently for
        every slice along the last dimension
    :param top_p: Cumulative threshold (float)
    :param temperature: Divisor applied before ranking and accumulating (float);
        the returned values are not rescaled
    :param binary: If True, kept entries are returned as 1 instead of their input
        value (using ``ones_like_with_grad``)
    :param at_least_one: If True, the highest-ranked entry is always kept, even if
        it alone already exceeds ``top_p``
    :return: Tensor with the same shape as ``logits`` in which every dropped entry
        is 0
    """
    # Apply temperature (affects only which entries are dropped).
    scaled = logits / temperature

    # Rank the entries and accumulate them from the largest downwards.
    sorted_vals, sorted_indices = torch.sort(scaled, dim=-1, descending=True)
    drop_sorted = torch.cumsum(sorted_vals, dim=-1) > top_p

    if at_least_one:
        drop_sorted[..., 0] = False

    # Undo the sort: position sorted_indices[..., j] receives the decision made
    # for rank j. Scattering along the last dim works for any leading shape.
    drop = torch.zeros_like(drop_sorted).scatter_(-1, sorted_indices, drop_sorted)

    if binary:
        logits = ones_like_with_grad(logits)
    return logits * torch.logical_not(drop)

class OnesLikeWithGrad(torch.autograd.Function):
    """Autograd function that outputs ones but passes gradients straight through.

    The forward result is ``torch.ones_like(input_tensor)``. The backward pass
    hands the incoming gradient to the input unchanged (as a copy).
    """

    @staticmethod
    def forward(ctx, input_tensor):
        # The values of the input are ignored; only its shape/dtype/device matter.
        return torch.ones_like(input_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: return a copy of the upstream gradient.
        passthrough = grad_output.clone()
        return passthrough

def ones_like_with_grad(input_tensor):
    """Return ``OnesLikeWithGrad.apply(input_tensor)``.

    Functional shorthand for applying the ``OnesLikeWithGrad`` autograd
    function to ``input_tensor``.
    """
    apply_ones = OnesLikeWithGrad.apply
    return apply_ones(input_tensor)
