# ------------------------------------------------------------------------
# Copyright (c) 2023 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------

from distutils.command.build import build
import enum
from turtle import down
import math
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding, build_transformer_layer_sequence
from mmcv.runner import BaseModule, force_fp32
from mmcv.cnn import xavier_init, constant_init, kaiming_init
from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
                        build_assigner, build_sampler, multi_apply,
                        reduce_mean, build_bbox_coder)
from mmdet.models.utils import build_transformer
from mmdet.models import HEADS, LOSSES, build_loss
from mmdet.models.utils import NormedLinear
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet3d.models.utils.clip_sigmoid import clip_sigmoid
from mmdet3d.models import builder
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
                          xywhr2xyxyr)
from einops import rearrange
import collections

from functools import reduce
from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox

from timm.models.layers import trunc_normal_

from typing import Optional


def pos2embed(pos, num_pos_feats=128, temperature=10000):
    """Encode the first two coordinates of ``pos`` as sine/cosine features.

    Args:
        pos: Tensor whose last dimension holds the coordinates; only index 0
            and index 1 of that dimension are read.
        num_pos_feats: Number of features generated per coordinate. The
            even/odd interleaving below needs an even value.
        temperature: Kept in the signature but not used by the computation.

    Returns:
        Tensor of shape ``pos.shape[:-1] + (2 * num_pos_feats,)``: the
        encoding of coordinate 1 followed by the encoding of coordinate 0.
    """
    scaled = pos * (2 * math.pi)
    feat_idx = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    # Consecutive feature pairs share one divisor, growing linearly from 1.
    divisors = 2 * (feat_idx // 2) / num_pos_feats + 1

    def _encode(axis):
        # sin on even feature slots, cos on odd ones, re-interleaved.
        ratio = scaled[..., axis, None] / divisors
        pair = torch.stack((ratio[..., 0::2].sin(), ratio[..., 1::2].cos()), dim=-1)
        return pair.flatten(-2)

    return torch.cat((_encode(1), _encode(0)), dim=-1)


class LayerNormFunction(torch.autograd.Function):
    """Grouped layer normalization over the channel axis with a manual backward.

    The ``(N, C, L)`` input is viewed as ``(N, groups, C // groups, L)`` and
    every group of channels is normalized independently at each ``(n, l)``
    location. A per-channel scale and shift are applied afterwards.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, groups, eps):
        """Normalize ``x`` and apply the affine transform.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            x: Input of shape ``(N, C, L)``; ``C`` must be divisible by
                ``groups``.
            weight: Per-channel scale with ``C`` elements.
            bias: Per-channel shift with ``C`` elements.
            groups: Number of channel groups normalized independently.
            eps: Value added to the variance before the square root.

        Returns:
            Tensor of shape ``(N, C, L)``.
        """
        ctx.groups = groups
        ctx.eps = eps
        N, C, L = x.size()
        x = x.view(N, groups, C // groups, L)
        mu = x.mean(2, keepdim=True)
        var = (x - mu).pow(2).mean(2, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # The normalized activations (not the raw input) are what backward needs.
        ctx.save_for_backward(y, var, weight)
        # reshape instead of view: merging the group axes must not fail when
        # the tensor happens to be non-contiguous.
        y = weight.view(1, C, 1) * y.reshape(N, C, L) + bias.view(1, C, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, groups, eps)``.

        ``groups`` and ``eps`` are non-tensor arguments and receive ``None``.
        """
        groups = ctx.groups
        eps = ctx.eps

        N, C, L = grad_output.size()
        # `saved_tensors` is the supported accessor; `saved_variables` is a
        # deprecated alias.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1)
        g = g.reshape(N, groups, C // groups, L)
        mean_g = g.mean(dim=2, keepdim=True)
        mean_gy = (g * y).mean(dim=2, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)

        grad_weight = (grad_output * y.reshape(N, C, L)).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=2).sum(dim=0)
        return gx.reshape(N, C, L), grad_weight, grad_bias, None, None


class GroupLayerNorm1d(nn.Module):
    """Module holding the affine parameters for ``LayerNormFunction``.

    Args:
        channels: Number of channels; sizes the ``weight`` and ``bias``
            parameters.
        groups: Number of channel groups forwarded to ``LayerNormFunction``.
        eps: Stabilizer forwarded to ``LayerNormFunction``.
    """

    def __init__(self, channels, groups=1, eps=1e-6):
        super().__init__()
        # Identity affine transform at initialization: scale 1, shift 0.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.groups, self.eps = groups, eps

    def forward(self, x):
        """Apply ``LayerNormFunction`` to ``x`` with this module's parameters."""
        scale, shift = self.weight, self.bias
        return LayerNormFunction.apply(x, scale, shift, self.groups, self.eps)


@HEADS.register_module()
class CmtSceneClassifierHead(BaseModule):
    """Scene-level classification head reading class tokens out of a transformer.

    One learnable token per class is prepended to the reference-point
    queries, the configured transformer is run over the BEV and image
    features, and a separate linear layer turns each output class token into
    a single logit.

    Args:
        in_channels: Input channels of ``shared_conv`` (the BEV feature map).
        num_classes: Number of class tokens and of output logits.
        pc_range: Six values; ``[:3]`` and ``[3:]`` are used as lower and
            upper bounds when normalizing 3-D coordinates, and
            ``pc_range[3]`` bounds the sampled depths.
        num_query: Number of learnable 3-D reference points.
        hidden_dim: Embedding width used throughout the head.
        depth_num: Number of depth samples used for the range-view
            position embeddings.
        norm_bbox, scalar, noise_scale, noise_trans, dn_weight, split:
            Stored on the instance; no method of this class reads them.
        downsample_scale: Divisor applied to ``grid_size`` to get the BEV
            grid used for position embeddings.
        train_cfg, test_cfg: Config mappings; ``grid_size`` is read from
            ``train_cfg`` when it is truthy, otherwise from ``test_cfg``.
        transformer: Config passed to ``build_transformer``.
        loss_cls: Config passed to ``build_loss``.
            NOTE(review): ``loss`` averages over dim 0 and then calls
            ``unbind(0)``, which presumably requires an unreduced
            ``(batch, num_classes)`` loss -- confirm the configured loss
            (and its ``reduction``) matches.
        mask_unimportant_devices: When True, ``loss`` multiplies the loss by
            ``1 - cls_labels['is_unimportant']``.
        init_cfg: Must be ``None``.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 pc_range,
                 num_query=900,
                 hidden_dim=128,
                 depth_num=64,
                 norm_bbox=True,
                 downsample_scale=8,
                 scalar=10,
                 noise_scale=1.0,
                 noise_trans=0.0,
                 dn_weight=1.0,
                 split=0.75,
                 train_cfg=None,
                 test_cfg=None,
                 transformer=None,
                 loss_cls=dict(
                     type="FocalLoss",
                     use_sigmoid=True,
                     reduction="mean",
                     gamma=2, alpha=0.25, loss_weight=1.0
                 ),
                 mask_unimportant_devices=False,
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None
        super(CmtSceneClassifierHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.num_query = num_query
        self.in_channels = in_channels
        self.depth_num = depth_num
        self.norm_bbox = norm_bbox
        self.downsample_scale = downsample_scale
        self.scalar = scalar
        self.bbox_noise_scale = noise_scale
        self.bbox_noise_trans = noise_trans
        self.dn_weight = dn_weight
        self.split = split
        self.pc_range = pc_range

        self.loss_cls = build_loss(loss_cls)
        self.mask_unimportant_devices = mask_unimportant_devices
        self.fp16_enabled = False

        self.shared_conv = ConvModule(
            in_channels,
            hidden_dim,
            kernel_size=3,
            padding=1,
            conv_cfg=dict(type="Conv2d"),
            norm_cfg=dict(type="BN2d")
        )

        # transformer
        self.transformer = build_transformer(transformer)
        self.reference_points = nn.Embedding(num_query, 3)
        # BEV embedding consumes the 2 * hidden_dim output of `pos2embed`.
        self.bev_embedding = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Range-view embedding consumes depth_num stacked (x, y, z) triples.
        self.rv_embedding = nn.Sequential(
            nn.Linear(self.depth_num * 3, self.hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        )

        # cls tokens: one learnable token per class
        self.num_cls_tokens = num_classes
        self.cls_tokens = nn.Parameter(torch.zeros(1, self.num_cls_tokens, hidden_dim))

        # head: an independent single-logit linear layer per class token
        self.head = nn.ModuleList(
            [nn.Linear(self.hidden_dim, 1) for _ in range(self.num_classes)]
        )

    def init_weights(self):
        """Initialize reference points uniformly in [0, 1] and tokens/heads as in DeiT."""
        super(CmtSceneClassifierHead, self).init_weights()
        nn.init.uniform_(self.reference_points.weight.data, 0, 1)

        # initialization from deit
        trunc_normal_(self.cls_tokens, std=.02)
        for h in self.head:
            trunc_normal_(h.weight, std=.02)
            if h.bias is not None:
                nn.init.constant_(h.bias, 0)

    @property
    def coords_bev(self):
        """Normalized cell-center coordinates of the downsampled BEV grid.

        Returns:
            CPU float tensor of shape ``(x_size * y_size, 2)`` with values in
            (0, 1). It is rebuilt on every access.
        """
        cfg = self.train_cfg if self.train_cfg else self.test_cfg
        x_size, y_size = (
            cfg['grid_size'][1] // self.downsample_scale,
            cfg['grid_size'][0] // self.downsample_scale
        )
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        batch_y, batch_x = torch.meshgrid(*[torch.linspace(it[0], it[1], it[2]) for it in meshgrid])
        # Shift to cell centers, then normalize by the grid extent.
        batch_x = (batch_x + 0.5) / x_size
        batch_y = (batch_y + 0.5) / y_size
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)
        coord_base = coord_base.view(2, -1).transpose(1, 0) # (H*W, 2)
        return coord_base

    def prepare_for_dn(self, batch_size, reference_points, img_metas):
        """Repeat the reference points for every sample in the batch.

        No denoising queries are generated here, so the returned attention
        mask and mask dict are always ``None``.

        Returns:
            Tuple ``(reference_points, attn_mask, mask_dict)`` where
            ``reference_points`` has shape ``(batch_size, num_query, 3)``.
        """
        padded_reference_points = reference_points.unsqueeze(0).repeat(batch_size, 1, 1)
        attn_mask = None
        mask_dict = None

        return padded_reference_points, attn_mask, mask_dict

    def _rv_pe(self, img_feats, img_metas):
        """Build position embeddings for the image (range-view) features.

        Every feature-map pixel is lifted to ``depth_num`` points along its
        ray, mapped to lidar coordinates with the inverse of ``lidar2img``,
        normalized by ``pc_range`` and embedded by ``rv_embedding``.

        Args:
            img_feats: Tensor of shape ``(B * N, C, H, W)``.
            img_metas: Per-sample meta dicts providing ``pad_shape`` and
                ``lidar2img``. Assumes each ``lidar2img`` stacks one 4x4
                matrix per view -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(B * N, H, W, hidden_dim)``.
        """
        _, _, H, W = img_feats.shape
        pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]
        # Pixel coordinates expressed in the padded input-image resolution.
        coords_h = torch.arange(H, device=img_feats[0].device).float() * pad_h / H
        coords_w = torch.arange(W, device=img_feats[0].device).float() * pad_w / W
        # Depths start at 1 and are evenly spaced up to (excluding) pc_range[3].
        coords_d = 1 + torch.arange(self.depth_num, device=img_feats[0].device).float() * (self.pc_range[3] - 1) / self.depth_num
        coords_h, coords_w, coords_d = torch.meshgrid([coords_h, coords_w, coords_d])

        coords = torch.stack([coords_w, coords_h, coords_d, coords_h.new_ones(coords_h.shape)], dim=-1)
        # Homogeneous image coordinates: (u * d, v * d, d, 1).
        coords[..., :2] = coords[..., :2] * coords[..., 2:3]

        imgs2lidars = np.concatenate([np.linalg.inv(meta['lidar2img']) for meta in img_metas])
        imgs2lidars = torch.from_numpy(imgs2lidars).float().to(coords.device)
        coords_3d = torch.einsum('hwdo, bco -> bhwdc', coords, imgs2lidars)
        # Normalize xyz by the point-cloud range.
        coords_3d = (coords_3d[..., :3] - coords_3d.new_tensor(self.pc_range[:3])[None, None, None, :] )\
                        / (coords_3d.new_tensor(self.pc_range[3:]) - coords_3d.new_tensor(self.pc_range[:3]))[None, None, None, :]
        # Flatten (depth, xyz) into one feature vector per pixel.
        return self.rv_embedding(coords_3d.reshape(*coords_3d.shape[:-2], -1))

    def _bev_query_embed(self, ref_points, img_metas):
        """Embed the first two coordinates of ``ref_points`` for the BEV branch."""
        bev_embeds = self.bev_embedding(pos2embed(ref_points, num_pos_feats=self.hidden_dim))
        return bev_embeds

    def _rv_query_embed(self, ref_points, img_metas):
        """Embed reference points through the camera geometry.

        The normalized points are scaled to ``pc_range``, projected into
        every view, lifted again along ``depth_num`` depths, mapped back to
        lidar coordinates, normalized and embedded. Embeddings of views in
        which a point is behind the camera or outside the padded image are
        zeroed before summing over views.

        Args:
            ref_points: Tensor of shape ``(B, num_query, 3)`` normalized by
                ``pc_range``.
            img_metas: Per-sample meta dicts providing ``pad_shape`` and
                ``lidar2img``. Assumes each ``lidar2img`` stacks one 4x4
                matrix per view -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(B, num_query, hidden_dim)``.
        """
        pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]
        lidars2imgs = np.stack([meta['lidar2img'] for meta in img_metas])
        lidars2imgs = torch.from_numpy(lidars2imgs).float().to(ref_points.device)
        imgs2lidars = np.stack([np.linalg.inv(meta['lidar2img']) for meta in img_metas])
        imgs2lidars = torch.from_numpy(imgs2lidars).float().to(ref_points.device)

        # Undo the pc_range normalization, then project into every view.
        ref_points = ref_points * (ref_points.new_tensor(self.pc_range[3:]) - ref_points.new_tensor(self.pc_range[:3])) + ref_points.new_tensor(self.pc_range[:3])
        proj_points = torch.einsum('bnd, bvcd -> bvnc', torch.cat([ref_points, ref_points.new_ones(*ref_points.shape[:-1], 1)], dim=-1), lidars2imgs)

        proj_points_clone = proj_points.clone()
        z_mask = proj_points_clone[..., 2:3].detach() > 0
        # Perspective divide with a sign-preserving epsilon to avoid dividing by zero.
        proj_points_clone[..., :3] = proj_points[..., :3] / (proj_points[..., 2:3].detach() + z_mask * 1e-6 - (~z_mask) * 1e-6)

        # Keep only points in front of the camera and inside the padded image.
        mask = (proj_points_clone[..., 0] < pad_w) & (proj_points_clone[..., 0] >= 0) & (proj_points_clone[..., 1] < pad_h) & (proj_points_clone[..., 1] >= 0)
        mask &= z_mask.squeeze(-1)

        # Same depth sampling as in `_rv_pe`.
        coords_d = 1 + torch.arange(self.depth_num, device=ref_points.device).float() * (self.pc_range[3] - 1) / self.depth_num
        proj_points_clone = torch.einsum('bvnc, d -> bvndc', proj_points_clone, coords_d)
        # Restore the homogeneous coordinate that the depth scaling overwrote.
        proj_points_clone = torch.cat([proj_points_clone[..., :3], proj_points_clone.new_ones(*proj_points_clone.shape[:-1], 1)], dim=-1)
        projback_points = torch.einsum('bvndo, bvco -> bvndc', proj_points_clone, imgs2lidars)

        projback_points = (projback_points[..., :3] - projback_points.new_tensor(self.pc_range[:3])[None, None, None, :] )\
                        / (projback_points.new_tensor(self.pc_range[3:]) - projback_points.new_tensor(self.pc_range[:3]))[None, None, None, :]

        rv_embeds = self.rv_embedding(projback_points.reshape(*projback_points.shape[:-2], -1))
        rv_embeds = (rv_embeds * mask.unsqueeze(-1)).sum(dim=1)
        return rv_embeds

    def query_embed(self, ref_points, img_metas):
        """Return ``(bev_embeds, rv_embeds)`` for the given reference points."""
        # Round-trip through inverse_sigmoid/sigmoid on a copy of the points.
        ref_points = inverse_sigmoid(ref_points.clone()).sigmoid()
        bev_embeds = self._bev_query_embed(ref_points, img_metas)
        rv_embeds = self._rv_query_embed(ref_points, img_metas)
        return bev_embeds, rv_embeds

    def build_key_padding_mask(self, x, x_img, img_metas):
        """Build a per-sample validity mask over BEV and image tokens.

        Args:
            x: BEV features of shape ``(B, C, H, W)``.
            x_img: Image features of shape ``(B * N, C, H_img, W_img)``.
            img_metas: Per-sample meta dicts. ``valid_points`` (optional)
                tells whether the BEV tokens are valid; ``valid_imgs``
                (optional) is presumably a 1-D boolean tensor with one entry
                per view -- verify against the data pipeline.

        Returns:
            ``None`` when no meta carries either key; otherwise a boolean
            tensor of shape ``(B, H * W + N * H_img * W_img)`` where True
            means valid and False means not valid.
        """
        if all('valid_imgs' not in meta and 'valid_points' not in meta for meta in img_metas):
            return None

        B, _, H, W = x.shape
        BN, _, H_img, W_img = x_img.shape
        N = BN // B
        bev_len = H * W
        rv_len = N * H_img * W_img

        key_padding_mask = []
        for meta in img_metas:
            # A missing key means "everything valid" for that modality.
            if meta.get('valid_points', True):
                bev_mask = torch.ones(bev_len, dtype=torch.bool, device=x.device)
            else:
                bev_mask = torch.zeros(bev_len, dtype=torch.bool, device=x.device)
            if 'valid_imgs' in meta:
                # Expand the per-view flag to every pixel of that view.
                rv_mask = meta['valid_imgs'].repeat_interleave(H_img * W_img).to(x.device)
            else:
                rv_mask = torch.ones(rv_len, dtype=torch.bool, device=x.device)
            key_padding_mask.append(torch.cat([bev_mask, rv_mask], dim=0))
        return torch.stack(key_padding_mask, dim=0)

    def build_cls_token_attn_mask(self, attn_mask=None, num_queries=None):
        """Extend a query self-attention mask with rows/columns for the cls tokens.

        The cls tokens can see all other tokens, while the object queries
        cannot see the cls tokens (True marks a blocked position).

        Args:
            attn_mask: Optional boolean mask of shape
                ``(num_queries, num_queries)``.
            num_queries: Used to create an all-False mask when ``attn_mask``
                is ``None``.

        Returns:
            Boolean mask of shape
            ``(num_cls_tokens + num_queries, num_cls_tokens + num_queries)``.

        Raises:
            ValueError: If both arguments are ``None``.
        """
        if attn_mask is None and num_queries is None:
            raise ValueError('`attn_mask` and `num_queries` cannot both be None.')

        if attn_mask is None:
            attn_mask = torch.zeros(num_queries, num_queries, dtype=torch.bool)

        # New leading columns: object queries are blocked from the cls tokens.
        attn_mask = torch.cat([torch.ones(attn_mask.shape[0], self.num_cls_tokens, dtype=torch.bool, device=attn_mask.device), attn_mask], dim=1)
        # New leading rows: cls tokens attend to everything.
        attn_mask = torch.cat([torch.zeros(self.num_cls_tokens, attn_mask.shape[1], dtype=torch.bool, device=attn_mask.device), attn_mask], dim=0)
        return attn_mask

    def forward_single(self, x, x_img, img_metas):
        """Run the head on one feature level.

        Args:
            x: BEV features of shape ``(bs, c, h, w)``.
            x_img: Image features of shape ``(bs * num_views, c, h, w)``.
            img_metas: Per-sample meta dicts.

        Returns:
            A list with a single dict ``{'logits': Tensor}`` where the
            logits have shape ``(bs, num_classes)``.
        """
        ret_dicts = []
        x = self.shared_conv(x)

        reference_points = self.reference_points.weight
        reference_points, attn_mask, _ = self.prepare_for_dn(x.shape[0], reference_points, img_metas)

        rv_pos_embeds = self._rv_pe(x_img, img_metas) # (B*N, ...)
        bev_pos_embeds = self.bev_embedding(pos2embed(self.coords_bev.to(x.device), num_pos_feats=self.hidden_dim))

        bev_query_embeds, rv_query_embeds = self.query_embed(reference_points, img_metas)
        query_embeds = bev_query_embeds + rv_query_embeds

        # Prepend the class tokens to the object queries.
        num_queries = query_embeds.shape[1]
        cls_tokens = self.cls_tokens.expand(x.shape[0], -1, -1)
        query_embeds = torch.cat([cls_tokens, query_embeds], dim=1)
        attn_mask = self.build_cls_token_attn_mask(attn_mask, num_queries)
        attn_mask = attn_mask.to(x.device)

        key_padding_mask = self.build_key_padding_mask(x, x_img, img_metas)

        outs_dec, _ = self.transformer(
                            x, x_img, query_embeds,
                            bev_pos_embeds, rv_pos_embeds,
                            attn_masks=attn_mask,
                            key_padding_mask=key_padding_mask,
                        )
        outs_dec = torch.nan_to_num(outs_dec) # [num_layers, bs, num_query, embed_dims]

        # Class tokens of the last decoder layer.
        outs_cls_tokens = outs_dec[-1, :, :self.num_cls_tokens] # [bs, num_cls_tokens, embed_dims]

        # Each class token goes through its own linear head -> [bs, 1].
        logits = [head(outs_cls_tokens[:, i]) for i, head in enumerate(self.head)]
        logits = torch.stack(logits, dim=1) # [bs, num_cls_tokens, 1]
        logits = logits.squeeze(-1)
        ret_dicts.append({'logits': logits})

        return ret_dicts

    def forward(self, pts_feats, img_feats=None, img_metas=None):
        """Apply ``forward_single`` to every feature level.

        Args:
            pts_feats: list of BEV feature maps, each ``[bs, c, h, w]``.
            img_feats: list of image feature maps, one per level.
            img_metas: Per-sample meta dicts, shared by all levels.

        Returns:
            The result of ``multi_apply`` over the levels.
        """
        img_metas = [img_metas for _ in range(len(pts_feats))]
        return multi_apply(self.forward_single, pts_feats, img_feats, img_metas)

    # `apply_to` must be an iterable of argument names; a bare string would be
    # matched by substring.
    @force_fp32(apply_to=('preds_dicts',))
    def loss(self, cls_labels, preds_dicts, **kwargs):
        """Compute one loss entry per class.

        Args:
            cls_labels: Mapping with ``label`` (same shape as the logits)
                and, when ``mask_unimportant_devices`` is set,
                ``is_unimportant``.
            preds_dicts: Output of ``forward``; the logits are read from
                ``preds_dicts[0][0]['logits']``.

        Returns:
            Dict mapping ``loss_cls_{idx}`` to the batch-averaged loss of
            class ``idx``.
        """
        logits = preds_dicts[0][0]['logits']
        labels = cls_labels['label'].to(logits.dtype)
        assert logits.shape == labels.shape, f'{logits.shape} != {labels.shape}'

        loss = self.loss_cls(logits, labels)
        if self.mask_unimportant_devices:
            # Only require `is_unimportant` when masking is actually enabled.
            loss_weight = 1 - cls_labels['is_unimportant'].to(loss.dtype)
            loss = loss * loss_weight
        loss = loss.mean(dim=0)
        return {f'loss_cls_{idx}': loss_cls for idx, loss_cls in enumerate(loss.unbind(0))}


@HEADS.register_module()
class CmtSceneClassifierHeadV2(BaseModule):
    """Scene-level classification head with a separate class-token extractor.

    The transformer is run on the reference-point queries only. Learnable
    class tokens then attend to the decoder outputs through ``extractor``
    and each resulting token is mapped to one logit by its own linear layer.

    Args:
        in_channels: Input channels of ``shared_conv`` (the BEV feature map).
        num_classes: Number of class tokens and of output logits.
        pc_range: Six values; ``[:3]`` and ``[3:]`` are used as lower and
            upper bounds when normalizing 3-D coordinates, and
            ``pc_range[3]`` bounds the sampled depths.
        num_query: Number of learnable 3-D reference points.
        hidden_dim: Embedding width used throughout the head.
        depth_num: Number of depth samples used for the range-view
            position embeddings.
        norm_bbox, scalar, noise_scale, noise_trans, dn_weight, split:
            Stored on the instance; no method of this class reads them.
        downsample_scale: Divisor applied to ``grid_size`` to get the BEV
            grid used for position embeddings.
        train_cfg, test_cfg: Config mappings; ``grid_size`` is read from
            ``train_cfg`` when it is truthy, otherwise from ``test_cfg``.
        transformer: Config passed to ``build_transformer``.
        extractor: Config passed to ``build_transformer_layer_sequence``;
            the module is called as ``extractor(query, key, value)`` with
            sequence-first tensors.
        num_feature_layer: How many trailing decoder layers are concatenated
            as key/value for the extractor. NOTE: ``0`` selects *all*
            layers because of how ``outs_dec[-0:]`` slices.
        loss_cls: Config passed to ``build_loss``.
            NOTE(review): ``loss`` averages over dim 0 and then calls
            ``unbind(0)``, which presumably requires an unreduced
            ``(batch, num_classes)`` loss -- confirm the configured loss
            (and its ``reduction``) matches.
        mask_unimportant_devices: When True, ``loss`` multiplies the loss by
            ``1 - cls_labels['is_unimportant']``.
        init_cfg: Must be ``None``.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 pc_range,
                 num_query=900,
                 hidden_dim=128,
                 depth_num=64,
                 norm_bbox=True,
                 downsample_scale=8,
                 scalar=10,
                 noise_scale=1.0,
                 noise_trans=0.0,
                 dn_weight=1.0,
                 split=0.75,
                 train_cfg=None,
                 test_cfg=None,
                 transformer=None,
                 extractor=None,
                 num_feature_layer=1,
                 loss_cls=dict(
                     type="FocalLoss",
                     use_sigmoid=True,
                     reduction="mean",
                     gamma=2, alpha=0.25, loss_weight=1.0
                 ),
                 mask_unimportant_devices=False,
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None
        super(CmtSceneClassifierHeadV2, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.num_query = num_query
        self.in_channels = in_channels
        self.depth_num = depth_num
        self.norm_bbox = norm_bbox
        self.downsample_scale = downsample_scale
        self.scalar = scalar
        self.bbox_noise_scale = noise_scale
        self.bbox_noise_trans = noise_trans
        self.dn_weight = dn_weight
        self.split = split
        self.pc_range = pc_range

        self.loss_cls = build_loss(loss_cls)
        self.mask_unimportant_devices = mask_unimportant_devices
        self.fp16_enabled = False

        self.shared_conv = ConvModule(
            in_channels,
            hidden_dim,
            kernel_size=3,
            padding=1,
            conv_cfg=dict(type="Conv2d"),
            norm_cfg=dict(type="BN2d")
        )

        # transformer
        self.transformer = build_transformer(transformer)
        self.reference_points = nn.Embedding(num_query, 3)
        # BEV embedding consumes the 2 * hidden_dim output of `pos2embed`.
        self.bev_embedding = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Range-view embedding consumes depth_num stacked (x, y, z) triples.
        self.rv_embedding = nn.Sequential(
            nn.Linear(self.depth_num * 3, self.hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        )

        # cls tokens: one learnable token per class
        self.num_cls_tokens = num_classes
        self.cls_tokens = nn.Parameter(torch.zeros(1, self.num_cls_tokens, hidden_dim))

        self.extractor = build_transformer_layer_sequence(extractor)
        self.num_feature_layer = num_feature_layer

        # head: an independent single-logit linear layer per class token
        self.head = nn.ModuleList(
            [nn.Linear(self.hidden_dim, 1) for _ in range(self.num_classes)]
        )

    def init_weights(self):
        """Initialize reference points uniformly in [0, 1] and tokens/heads as in DeiT."""
        super(CmtSceneClassifierHeadV2, self).init_weights()
        nn.init.uniform_(self.reference_points.weight.data, 0, 1)

        # initialization from deit
        trunc_normal_(self.cls_tokens, std=.02)
        for h in self.head:
            trunc_normal_(h.weight, std=.02)
            if h.bias is not None:
                nn.init.constant_(h.bias, 0)

    @property
    def coords_bev(self):
        """Normalized cell-center coordinates of the downsampled BEV grid.

        Returns:
            CPU float tensor of shape ``(x_size * y_size, 2)`` with values in
            (0, 1). It is rebuilt on every access.
        """
        cfg = self.train_cfg if self.train_cfg else self.test_cfg
        x_size, y_size = (
            cfg['grid_size'][1] // self.downsample_scale,
            cfg['grid_size'][0] // self.downsample_scale
        )
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        batch_y, batch_x = torch.meshgrid(*[torch.linspace(it[0], it[1], it[2]) for it in meshgrid])
        # Shift to cell centers, then normalize by the grid extent.
        batch_x = (batch_x + 0.5) / x_size
        batch_y = (batch_y + 0.5) / y_size
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)
        coord_base = coord_base.view(2, -1).transpose(1, 0) # (H*W, 2)
        return coord_base

    def prepare_for_dn(self, batch_size, reference_points, img_metas):
        """Repeat the reference points for every sample in the batch.

        No denoising queries are generated here, so the returned attention
        mask and mask dict are always ``None``.

        Returns:
            Tuple ``(reference_points, attn_mask, mask_dict)`` where
            ``reference_points`` has shape ``(batch_size, num_query, 3)``.
        """
        padded_reference_points = reference_points.unsqueeze(0).repeat(batch_size, 1, 1)
        attn_mask = None
        mask_dict = None

        return padded_reference_points, attn_mask, mask_dict

    def _rv_pe(self, img_feats, img_metas):
        """Build position embeddings for the image (range-view) features.

        Every feature-map pixel is lifted to ``depth_num`` points along its
        ray, mapped to lidar coordinates with the inverse of ``lidar2img``,
        normalized by ``pc_range`` and embedded by ``rv_embedding``.

        Args:
            img_feats: Tensor of shape ``(B * N, C, H, W)``.
            img_metas: Per-sample meta dicts providing ``pad_shape`` and
                ``lidar2img``. Assumes each ``lidar2img`` stacks one 4x4
                matrix per view -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(B * N, H, W, hidden_dim)``.
        """
        _, _, H, W = img_feats.shape
        pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]
        # Pixel coordinates expressed in the padded input-image resolution.
        coords_h = torch.arange(H, device=img_feats[0].device).float() * pad_h / H
        coords_w = torch.arange(W, device=img_feats[0].device).float() * pad_w / W
        # Depths start at 1 and are evenly spaced up to (excluding) pc_range[3].
        coords_d = 1 + torch.arange(self.depth_num, device=img_feats[0].device).float() * (self.pc_range[3] - 1) / self.depth_num
        coords_h, coords_w, coords_d = torch.meshgrid([coords_h, coords_w, coords_d])

        coords = torch.stack([coords_w, coords_h, coords_d, coords_h.new_ones(coords_h.shape)], dim=-1)
        # Homogeneous image coordinates: (u * d, v * d, d, 1).
        coords[..., :2] = coords[..., :2] * coords[..., 2:3]

        imgs2lidars = np.concatenate([np.linalg.inv(meta['lidar2img']) for meta in img_metas])
        imgs2lidars = torch.from_numpy(imgs2lidars).float().to(coords.device)
        coords_3d = torch.einsum('hwdo, bco -> bhwdc', coords, imgs2lidars)
        # Normalize xyz by the point-cloud range.
        coords_3d = (coords_3d[..., :3] - coords_3d.new_tensor(self.pc_range[:3])[None, None, None, :] )\
                        / (coords_3d.new_tensor(self.pc_range[3:]) - coords_3d.new_tensor(self.pc_range[:3]))[None, None, None, :]
        # Flatten (depth, xyz) into one feature vector per pixel.
        return self.rv_embedding(coords_3d.reshape(*coords_3d.shape[:-2], -1))

    def _bev_query_embed(self, ref_points, img_metas):
        """Embed the first two coordinates of ``ref_points`` for the BEV branch."""
        bev_embeds = self.bev_embedding(pos2embed(ref_points, num_pos_feats=self.hidden_dim))
        return bev_embeds

    def _rv_query_embed(self, ref_points, img_metas):
        """Embed reference points through the camera geometry.

        The normalized points are scaled to ``pc_range``, projected into
        every view, lifted again along ``depth_num`` depths, mapped back to
        lidar coordinates, normalized and embedded. Embeddings of views in
        which a point is behind the camera or outside the padded image are
        zeroed before summing over views.

        Args:
            ref_points: Tensor of shape ``(B, num_query, 3)`` normalized by
                ``pc_range``.
            img_metas: Per-sample meta dicts providing ``pad_shape`` and
                ``lidar2img``. Assumes each ``lidar2img`` stacks one 4x4
                matrix per view -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(B, num_query, hidden_dim)``.
        """
        pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]
        lidars2imgs = np.stack([meta['lidar2img'] for meta in img_metas])
        lidars2imgs = torch.from_numpy(lidars2imgs).float().to(ref_points.device)
        imgs2lidars = np.stack([np.linalg.inv(meta['lidar2img']) for meta in img_metas])
        imgs2lidars = torch.from_numpy(imgs2lidars).float().to(ref_points.device)

        # Undo the pc_range normalization, then project into every view.
        ref_points = ref_points * (ref_points.new_tensor(self.pc_range[3:]) - ref_points.new_tensor(self.pc_range[:3])) + ref_points.new_tensor(self.pc_range[:3])
        proj_points = torch.einsum('bnd, bvcd -> bvnc', torch.cat([ref_points, ref_points.new_ones(*ref_points.shape[:-1], 1)], dim=-1), lidars2imgs)

        proj_points_clone = proj_points.clone()
        z_mask = proj_points_clone[..., 2:3].detach() > 0
        # Perspective divide with a sign-preserving epsilon to avoid dividing by zero.
        proj_points_clone[..., :3] = proj_points[..., :3] / (proj_points[..., 2:3].detach() + z_mask * 1e-6 - (~z_mask) * 1e-6)

        # Keep only points in front of the camera and inside the padded image.
        mask = (proj_points_clone[..., 0] < pad_w) & (proj_points_clone[..., 0] >= 0) & (proj_points_clone[..., 1] < pad_h) & (proj_points_clone[..., 1] >= 0)
        mask &= z_mask.squeeze(-1)

        # Same depth sampling as in `_rv_pe`.
        coords_d = 1 + torch.arange(self.depth_num, device=ref_points.device).float() * (self.pc_range[3] - 1) / self.depth_num
        proj_points_clone = torch.einsum('bvnc, d -> bvndc', proj_points_clone, coords_d)
        # Restore the homogeneous coordinate that the depth scaling overwrote.
        proj_points_clone = torch.cat([proj_points_clone[..., :3], proj_points_clone.new_ones(*proj_points_clone.shape[:-1], 1)], dim=-1)
        projback_points = torch.einsum('bvndo, bvco -> bvndc', proj_points_clone, imgs2lidars)

        projback_points = (projback_points[..., :3] - projback_points.new_tensor(self.pc_range[:3])[None, None, None, :] )\
                        / (projback_points.new_tensor(self.pc_range[3:]) - projback_points.new_tensor(self.pc_range[:3]))[None, None, None, :]

        rv_embeds = self.rv_embedding(projback_points.reshape(*projback_points.shape[:-2], -1))
        rv_embeds = (rv_embeds * mask.unsqueeze(-1)).sum(dim=1)
        return rv_embeds

    def query_embed(self, ref_points, img_metas):
        """Return ``(bev_embeds, rv_embeds)`` for the given reference points."""
        # Round-trip through inverse_sigmoid/sigmoid on a copy of the points.
        ref_points = inverse_sigmoid(ref_points.clone()).sigmoid()
        bev_embeds = self._bev_query_embed(ref_points, img_metas)
        rv_embeds = self._rv_query_embed(ref_points, img_metas)
        return bev_embeds, rv_embeds

    def build_key_padding_mask(self, x, x_img, img_metas):
        """Build a per-sample validity mask over BEV and image tokens.

        Args:
            x: BEV features of shape ``(B, C, H, W)``.
            x_img: Image features of shape ``(B * N, C, H_img, W_img)``.
            img_metas: Per-sample meta dicts. ``valid_points`` (optional)
                tells whether the BEV tokens are valid; ``valid_imgs``
                (optional) is presumably a 1-D boolean tensor with one entry
                per view -- verify against the data pipeline.

        Returns:
            ``None`` when no meta carries either key; otherwise a boolean
            tensor of shape ``(B, H * W + N * H_img * W_img)`` where True
            means valid and False means not valid.
        """
        if all('valid_imgs' not in meta and 'valid_points' not in meta for meta in img_metas):
            return None

        B, _, H, W = x.shape
        BN, _, H_img, W_img = x_img.shape
        N = BN // B
        bev_len = H * W
        rv_len = N * H_img * W_img

        key_padding_mask = []
        for meta in img_metas:
            # A missing key means "everything valid" for that modality.
            if meta.get('valid_points', True):
                bev_mask = torch.ones(bev_len, dtype=torch.bool, device=x.device)
            else:
                bev_mask = torch.zeros(bev_len, dtype=torch.bool, device=x.device)
            if 'valid_imgs' in meta:
                # Expand the per-view flag to every pixel of that view.
                rv_mask = meta['valid_imgs'].repeat_interleave(H_img * W_img).to(x.device)
            else:
                rv_mask = torch.ones(rv_len, dtype=torch.bool, device=x.device)
            key_padding_mask.append(torch.cat([bev_mask, rv_mask], dim=0))
        return torch.stack(key_padding_mask, dim=0)

    def build_cls_token_attn_mask(self, attn_mask=None, num_queries=None):
        """Extend a query self-attention mask with rows/columns for cls tokens.

        The cls tokens can see all other tokens, while the object queries
        cannot see the cls tokens (True marks a blocked position). This
        class's ``forward_single`` does not call this method.

        Args:
            attn_mask: Optional boolean mask of shape
                ``(num_queries, num_queries)``.
            num_queries: Used to create an all-False mask when ``attn_mask``
                is ``None``.

        Returns:
            Boolean mask of shape
            ``(num_cls_tokens + num_queries, num_cls_tokens + num_queries)``.

        Raises:
            ValueError: If both arguments are ``None``.
        """
        if attn_mask is None and num_queries is None:
            raise ValueError('`attn_mask` and `num_queries` cannot both be None.')

        if attn_mask is None:
            attn_mask = torch.zeros(num_queries, num_queries, dtype=torch.bool)

        # New leading columns: object queries are blocked from the cls tokens.
        attn_mask = torch.cat([torch.ones(attn_mask.shape[0], self.num_cls_tokens, dtype=torch.bool, device=attn_mask.device), attn_mask], dim=1)
        # New leading rows: cls tokens attend to everything.
        attn_mask = torch.cat([torch.zeros(self.num_cls_tokens, attn_mask.shape[1], dtype=torch.bool, device=attn_mask.device), attn_mask], dim=0)
        return attn_mask

    def forward_single(self, x, x_img, img_metas):
        """Run the head on one feature level.

        Args:
            x: BEV features of shape ``(bs, c, h, w)``.
            x_img: Image features of shape ``(bs * num_views, c, h, w)``.
            img_metas: Per-sample meta dicts.

        Returns:
            A list with a single dict ``{'logits': Tensor}`` where the
            logits have shape ``(bs, num_classes)``.
        """
        ret_dicts = []
        x = self.shared_conv(x)

        reference_points = self.reference_points.weight
        reference_points, attn_mask, _ = self.prepare_for_dn(x.shape[0], reference_points, img_metas)

        rv_pos_embeds = self._rv_pe(x_img, img_metas) # (B*N, ...)
        bev_pos_embeds = self.bev_embedding(pos2embed(self.coords_bev.to(x.device), num_pos_feats=self.hidden_dim))

        bev_query_embeds, rv_query_embeds = self.query_embed(reference_points, img_metas)
        query_embeds = bev_query_embeds + rv_query_embeds

        # The class tokens are NOT concatenated to the queries here; they are
        # fed to the extractor after decoding. `attn_mask` therefore stays as
        # returned by `prepare_for_dn` (None).
        key_padding_mask = self.build_key_padding_mask(x, x_img, img_metas)

        outs_dec, _ = self.transformer(
                            x, x_img, query_embeds,
                            bev_pos_embeds, rv_pos_embeds,
                            attn_masks=attn_mask,
                            key_padding_mask=key_padding_mask,
                        )
        outs_dec = torch.nan_to_num(outs_dec) # [num_layers, bs, num_query, embed_dims]

        # Sequence-first class tokens: [num_cls_tokens, bs, embed_dims].
        cls_tokens = self.cls_tokens.expand(x.shape[0], -1, -1).transpose(0, 1)
        extractor_kv = outs_dec[-self.num_feature_layer:] # [num_layers, bs, num_kv, embed_dims]
        extractor_kv = extractor_kv.transpose(1, 2) # [num_layers, num_kv, bs, embed_dims]
        extractor_kv = extractor_kv.flatten(0, 1) # [num_layers * num_kv, bs, embed_dims]
        outs_cls_tokens = self.extractor(cls_tokens, extractor_kv, extractor_kv)
        outs_cls_tokens = outs_cls_tokens.transpose(0, 1) # back to batch-first

        # Each class token goes through its own linear head -> [bs, 1].
        logits = [head(outs_cls_tokens[:, i]) for i, head in enumerate(self.head)]
        logits = torch.stack(logits, dim=1) # [bs, num_cls_tokens, 1]
        logits = logits.squeeze(-1)
        ret_dicts.append({'logits': logits})

        return ret_dicts

    def forward(self, pts_feats, img_feats=None, img_metas=None):
        """Apply ``forward_single`` to every feature level.

        Args:
            pts_feats: list of BEV feature maps, each ``[bs, c, h, w]``.
            img_feats: list of image feature maps, one per level.
            img_metas: Per-sample meta dicts, shared by all levels.

        Returns:
            The result of ``multi_apply`` over the levels.
        """
        img_metas = [img_metas for _ in range(len(pts_feats))]
        return multi_apply(self.forward_single, pts_feats, img_feats, img_metas)

    # `apply_to` must be an iterable of argument names; a bare string would be
    # matched by substring.
    @force_fp32(apply_to=('preds_dicts',))
    def loss(self, cls_labels, preds_dicts, **kwargs):
        """Compute one loss entry per class.

        Args:
            cls_labels: Mapping with ``label`` (same shape as the logits)
                and, when ``mask_unimportant_devices`` is set,
                ``is_unimportant``.
            preds_dicts: Output of ``forward``; the logits are read from
                ``preds_dicts[0][0]['logits']``.

        Returns:
            Dict mapping ``loss_cls_{idx}`` to the batch-averaged loss of
            class ``idx``.
        """
        logits = preds_dicts[0][0]['logits']
        labels = cls_labels['label'].to(logits.dtype)
        assert logits.shape == labels.shape, f'{logits.shape} != {labels.shape}'

        loss = self.loss_cls(logits, labels)
        if self.mask_unimportant_devices:
            # Only require `is_unimportant` when masking is actually enabled.
            loss_weight = 1 - cls_labels['is_unimportant'].to(loss.dtype)
            loss = loss * loss_weight
        loss = loss.mean(dim=0)
        return {f'loss_cls_{idx}': loss_cls for idx, loss_cls in enumerate(loss.unbind(0))}


@LOSSES.register_module()
class SigmoidFocalLoss(nn.Module):
    """Registry-buildable wrapper around ``sigmoid_focal_loss``.

    Args:
        alpha: Forwarded as ``alpha`` to ``sigmoid_focal_loss``.
        gamma: Forwarded as ``gamma`` to ``sigmoid_focal_loss``.
        reduction: Forwarded as ``reduction`` to ``sigmoid_focal_loss``.
        loss_weight: Scalar multiplied onto the returned loss.
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = 'none',
        loss_weight: float = 1.0,
    ) -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return ``loss_weight`` times the focal loss of ``inputs`` vs ``targets``."""
        focal = sigmoid_focal_loss(
            inputs,
            targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        return self.loss_weight * focal


@LOSSES.register_module()
class BCEWithLogitsLoss(nn.Module):
    """Registry-buildable wrapper around ``F.binary_cross_entropy_with_logits``.

    Args:
        weight: Optional rescaling weight forwarded to the functional loss.
            A tensor is stored as-is; any other non-``None`` value (e.g. a
            list or number coming from a config file) is converted to a
            float32 tensor first.
        pos_weight: Optional weight of positive examples, handled like
            ``weight``.
        reduction: Forwarded as ``reduction`` to the functional loss.
        loss_weight: Scalar multiplied onto the returned loss.
    """

    def __init__(self,
        weight: Optional[torch.Tensor] = None,
        pos_weight: Optional[torch.Tensor] = None,
        reduction: str = 'none',
        loss_weight: float = 1.0,
    )-> None:
        super().__init__()
        # Buffers follow the module across devices and into the state dict.
        self.register_buffer('weight', self._to_buffer(weight))
        self.register_buffer('pos_weight', self._to_buffer(pos_weight))
        self.weight: Optional[torch.Tensor]
        self.pos_weight: Optional[torch.Tensor]
        self.reduction = reduction
        self.loss_weight = loss_weight

    @staticmethod
    def _to_buffer(value):
        """Return ``value`` unchanged if it is ``None`` or a tensor, else a float32 tensor."""
        # `register_buffer` only accepts tensors or None, while configs
        # typically provide plain Python lists/numbers.
        if value is None or isinstance(value, torch.Tensor):
            return value
        return torch.as_tensor(value, dtype=torch.float32)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``loss_weight`` times the BCE-with-logits loss of ``input`` vs ``target``."""
        return self.loss_weight * F.binary_cross_entropy_with_logits(
            input,
            target,
            self.weight,
            pos_weight=self.pos_weight,
            reduction=self.reduction
        )
