import math
import warnings
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmcv.cnn import xavier_init
from mmdet.core import multi_apply, reduce_mean
from mmdet.models import HEADS
from mmdet.models.dense_heads import DETRHead
from mmdet.models.builder import build_loss
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes

from einops.layers.torch import Rearrange
from einops import rearrange, repeat

from .bbox.utils import normalize_bbox, encode_bbox
from .utils import VERSION
from .world_model_hd import RSSM
from .losses import WMLoss


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with draws from N(mean, std^2) truncated to [a, b].

    Sampling uses the inverse-CDF trick: draw uniformly between the CDF values
    of the two bounds, map back through ``erfinv``, then rescale and shift.
    Adapted from the PyTorch ``nn.init.trunc_normal_`` implementation; method
    reference: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor that is overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff.
        b: upper cutoff.

    Returns:
        The same ``tensor`` object, after the in-place fill.
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        # CDF of the standard normal distribution, expressed through erf.
        return (1. + math.erf(value / sqrt_two)) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # CDF mass at the lower / upper cutoff of the requested interval.
    cdf_low = standard_normal_cdf((a - mean) / std)
    cdf_high = standard_normal_cdf((b - mean) / std)

    # Uniform sample in [2*cdf_low - 1, 2*cdf_high - 1]: the domain erfinv
    # expects in order to produce a truncated standard normal.
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
    tensor.erfinv_()

    # Scale / shift the standard-normal sample to the requested mean and std.
    tensor.mul_(std * sqrt_two)
    tensor.add_(mean)

    # Numerical safety net: keep every value inside [a, b].
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampler behaves best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: as in PyTorch's own ``trunc_normal_``, the bounds ``a`` and ``b`` are
    expressed in the units of the final (shifted and scaled) distribution, so
    they have to be chosen to suit the given ``mean`` and ``std``.

    The fill runs under ``torch.no_grad()``, so it is not recorded by autograd.

    Args:
        tensor: an n-dimensional ``torch.Tensor``, modified in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor`` after it has been filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


class FeedForward(nn.Module):
    """Pre-normalised two-layer MLP block.

    Layer order: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, dim) -> Dropout. The last dimension of the output has
    the same size (``dim``) as the input.

    Args:
        dim: size of the last input/output dimension.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # Kept as a single ``net`` Sequential so parameter names stay ``net.<idx>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        return self.net(x)

class FlowAttention(nn.Module):
    """Self-attention with separate LayerNorms for the query, key and value inputs.

    Two modes are supported:

    * ``channel_fuse=False``: multi-head attention. The projected q/k/v are
      split into ``heads`` heads along the channel axis
      (``'b n (h d) -> b h n d'``) and merged back afterwards.
    * ``channel_fuse=True``: the projections have width ``dim_head`` and are
      not split into heads; attention runs over the second-to-last axis of the
      input as given, with all leading axes acting as batch axes.

    Args:
        dim: channel size of the input (last dimension).
        heads: number of heads (only used for the split when ``channel_fuse``
            is False).
        dim_head: per-head width; also sets the softmax scale ``dim_head ** -0.5``.
        dropout: dropout probability for attention weights and output projection.
        with_res: if True, the result is passed through
            ``res_norm(x + res_mlp(out + x))``.
        channel_fuse: selects the mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        inner_dim = dim_head if channel_fuse else dim_head * heads
        # With one head whose width already equals ``dim`` no projection is needed.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if self.with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, x, pos=None):
        """Run self-attention on ``x``.

        Args:
            x: input features with ``dim`` channels in the last axis.
            pos: optional positional encoding, added to the query and key
                inputs only (the value input is always ``x``).

        Returns:
            Attended features; when ``with_res`` is set, the residual/MLP/norm
            post-processing is applied on top.
        """
        # Positional information only steers the attention weights (q/k).
        qk_source = x if pos is None else x + pos
        q = self.to_q(self.norm_q(qk_source))
        k = self.to_k(self.norm_k(qk_source))
        v = self.to_v(self.norm_v(x))

        split_heads = not self.channel_fuse
        if split_heads:
            q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        out = torch.matmul(weights, v)

        if split_heads:
            out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if self.with_res:
            out = self.res_norm(x + self.res_mlp(out + x))
        return out

class FlowCrossAttention(nn.Module):
    """Cross-attention from a query tensor to a separate key/value tensor.

    Two modes are supported:

    * ``channel_fuse=False``: multi-head attention on 3-D inputs; channels are
      split into ``heads`` heads (``'b n (h d) -> b h n d'``).
    * ``channel_fuse=True``: inputs are 4-D ``(b, n, h, d)``; the third axis is
      moved in front of the second one and must have length ``heads``, so
      attention runs over the ``n`` axis independently for each ``h`` entry.

    Args:
        dim: channel size of the inputs (last dimension).
        heads: number of heads, or the required length of the ``h`` axis when
            ``channel_fuse`` is True.
        dim_head: per-head width; also sets the softmax scale ``dim_head ** -0.5``.
        dropout: dropout probability for attention weights and output projection.
        with_res: if True, the result is passed through
            ``res_norm(q_ini + res_mlp(out + q_ini))``.
        channel_fuse: selects the mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        inner_dim = dim_head if channel_fuse else dim_head * heads
        # With one head whose width already equals ``dim`` no projection is needed.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if self.with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, q_ini, kv_ini, q_pos=None, k_pos=None, flatten=True):
        """Attend from ``q_ini`` to ``kv_ini``.

        Args:
            q_ini: query features.
            kv_ini: features used for both keys and values.
            q_pos: optional positional encoding added to the query input.
            k_pos: optional positional encoding added to the key input.
                Positional encodings are applied only when BOTH are given.
            flatten: only honoured when ``channel_fuse`` is True; if set, the
                last two axes of the 4-D result are flattened into one.

        Returns:
            The attended query features.
        """
        if q_pos is None or k_pos is None:
            q_source, k_source = q_ini, kv_ini
        else:
            q_source, k_source = q_ini + q_pos, kv_ini + k_pos

        q = self.to_q(self.norm_q(q_source))
        k = self.to_k(self.norm_k(k_source))
        v = self.to_v(self.norm_v(kv_ini))

        # The two modes differ only in how the head axis is exposed and
        # folded back around the attention product.
        if self.channel_fuse:
            split_pattern, merge_pattern = 'b n h d -> b h n d', 'b h n d -> b n h d'
        else:
            split_pattern, merge_pattern = 'b n (h d) -> b h n d', 'b h n d -> b n (h d)'

        q, k, v = (rearrange(t, split_pattern, h = self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        out = self.to_out(rearrange(torch.matmul(weights, v), merge_pattern))

        if self.with_res:
            out = self.res_norm(q_ini + self.res_mlp(out + q_ini))
        if self.channel_fuse and flatten:
            out = out.flatten(2, 3)
        return out


@HEADS.register_module()
class FlowADHead(DETRHead):
    def __init__(self,
                 *args,
                 num_classes,
                 in_channels,
                 query_denoising=True,
                 query_denoising_groups=10,
                 bbox_coder=None,
                 code_size=10,
                 code_weights=[1.0] * 10,
                 train_cfg=dict(),
                 test_cfg=dict(max_per_img=100),
                 flow_patches=[8,4,2,1],
                 flow_ids=[0,1,2,3],
                 flow_attn_ids=[2,3],
                 flow_wm_ids=[2,3],
                 flow_pad_single=1,
                 flow_num=22,
                 spat_size=6,
                 loss_wm=dict(
                    type='WMLoss',
                    loss_weight=1.0,
                ),
                loss_wm_spatial=dict(
                    type='WMLoss',
                    loss_weight=0.25,
                ),
                 **kwargs):
        """Build the detection head together with its flow-attention and world-model modules.

        Args:
            *args: accepted for signature compatibility; NOT forwarded to the
                parent constructor.
            num_classes: number of classes (keyword-only); forwarded to the parent.
            in_channels: feature channel size (keyword-only); also used as
                ``embed_dims`` and as the channel size of every module built here.
            query_denoising: stored as ``dn_enabled``.
            query_denoising_groups: stored as ``dn_group_num``.
            bbox_coder: config passed to ``build_bbox_coder``; the resulting
                coder's ``pc_range`` is cached on the head.
            code_size: stored as ``self.code_size``.
            code_weights: list converted into a non-trainable ``nn.Parameter``.
            train_cfg: forwarded to the parent constructor.
            test_cfg: forwarded to the parent constructor.
            flow_patches: patch width per feature level; indexed by the values
                in ``flow_ids`` to size the flow MLPs.
            flow_ids: indices of the feature levels that get a flow MLP.
            flow_attn_ids: one ``FlowAttention`` module is built per entry.
            flow_wm_ids: one temporal RSSM, one spatial RSSM and two fusion
                convolutions are built per entry.
            flow_pad_single: number of padded patches on each side; the padded
                window holds ``1 + 2 * flow_pad_single`` patches.
            flow_num: number of flow points per view; also passed as ``heads``
                to the channel-fused patch attention modules.
            spat_size: stored as ``self.spat_size``.
            loss_wm: loss config for the temporal world model.
            loss_wm_spatial: loss config for the spatial world model.
            **kwargs: forwarded to the parent constructor.
        """
        # These attributes are assigned BEFORE super().__init__().
        # NOTE(review): presumably the parent constructor invokes
        # `_init_layers`, which reads `embed_dims` / `flow_num` -- confirm
        # against DETRHead before reordering anything here.
        self.code_size = code_size
        self.code_weights = code_weights
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = in_channels

        self.flow_num = flow_num
        self.spat_size = spat_size

        super(FlowADHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)

        # Re-wrap the plain list as a frozen Parameter (no gradient is tracked for it).
        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range

        # Query-denoising settings; the weight and noise scales are fixed constants.
        self.dn_enabled = query_denoising
        self.dn_group_num = query_denoising_groups
        self.dn_weight = 1.0
        self.dn_bbox_noise_scale = 0.5
        self.dn_label_noise_scale = 0.5

        self.flow_patches = flow_patches
        self.flow_ids = flow_ids
        self.flow_attn_ids = flow_attn_ids
        self.flow_wm_ids = flow_wm_ids
        self.flow_pad_single = flow_pad_single
        # Width (in patches) of one padded window: the patch itself plus padding on both sides.
        self.flow_pad_sum = 1 + 2 * flow_pad_single
        # Row i holds [(i - 1) % 6, (i + 1) % 6]: the left / right neighbour
        # index of view i in a ring of six views.
        self.ids_pad_left_right = torch.tensor([[5,1],[0,2],[1,3],[2,4],[3,5],[4,0]], dtype=torch.long, requires_grad=False)
        # Hard-coded; passed as `heads` to `patch_crossattn_spatial` below.
        self.frame_num = 8

        # One linear layer per flow level, mapping a padded window of
        # flow_pad_sum * patch values back to a single patch.
        self.flow_mlps = nn.ModuleList([
            nn.Linear(self.flow_pad_sum*self.flow_patches[flow_id], self.flow_patches[flow_id]) for flow_id in self.flow_ids
        ])
        # One 4-head self-attention module per level listed in flow_attn_ids.
        self.flow_attns = nn.ModuleList([
            FlowAttention(dim=in_channels, heads=4, dim_head=in_channels//4, dropout=0.) for flow_attn_id in self.flow_attn_ids
        ])

        # One temporal and one spatial RSSM per level listed in flow_wm_ids.
        self.world_models_temporal = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for flow_wm_id in self.flow_wm_ids
        ])
        self.world_models_spatial = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for flow_wm_id in self.flow_wm_ids
        ])
        # Channel-fused attention modules (no head split along the channel axis).
        self.patch_selfattn = FlowAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_local = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        self.patch_crossattn_temporal = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_spatial = FlowCrossAttention(dim=in_channels, heads=self.frame_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        # 3x3 convolutions reducing 2 * in_channels back to in_channels, one
        # pair (temporal / spatial) per world-model level.
        self.flow_wm_convs = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])
        self.flow_wm_convs_spatial = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])

        self.loss_wm_temporal = build_loss(loss_wm)
        self.loss_wm_spatial = build_loss(loss_wm_spatial)

    def _init_layers(self):
        """Create the query embeddings and the small embedding MLPs of the head.

        Builds, in this order: the initial box queries, the label encoder, the
        CAN-bus MLP, the fixed flow query coordinates, and the MLPs for the
        flow query / spatial PE (2-D inputs) and temporal PE (18-D inputs).
        Requires ``num_query`` to be a perfect square.
        """
        embed_dims = self.embed_dims
        half_dims = embed_dims // 2

        def make_embed_mlp(in_features):
            # Linear -> ReLU -> Linear -> ReLU -> LayerNorm, lifting
            # `in_features` inputs to `embed_dims` channels.
            mlp = nn.Sequential(
                nn.Linear(in_features, half_dims),
                nn.ReLU(inplace=True),
                nn.Linear(half_dims, embed_dims),
                nn.ReLU(inplace=True),
                nn.LayerNorm(embed_dims)
            )
            # NOTE(review): this is applied to the nn.Sequential container
            # itself; confirm that mmcv's xavier_init actually touches the
            # Linear layers inside it rather than being a no-op.
            xavier_init(mlp, distribution='uniform', bias=0.)
            return mlp

        self.init_query_bbox = nn.Embedding(self.num_query, 10)  # (x, y, z, w, l, h, sin, cos, vx, vy)
        self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)  # DAB-DETR

        bbox_weight = self.init_query_bbox.weight
        nn.init.zeros_(bbox_weight[:, 2:3])            # z
        nn.init.zeros_(bbox_weight[:, 8:10])           # vx, vy
        nn.init.constant_(bbox_weight[:, 5:6], 1.5)    # h

        # Lay the query centres out on a regular grid_size x grid_size grid.
        grid_size = int(math.sqrt(self.num_query))
        assert grid_size * grid_size == self.num_query
        axis = torch.arange(grid_size)
        grid_x, grid_y = torch.meshgrid(axis, axis, indexing='ij')  # values in [0, grid_size - 1]
        cell_centers = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=-1)
        cell_centers = (cell_centers + 0.5) / grid_size  # cell centres, strictly inside (0, 1)
        with torch.no_grad():
            bbox_weight[:, :2] = cell_centers.reshape(-1, 2)  # [Q, 2]

        self.can_bus_mlp = make_embed_mlp(18)

        self.flow_query_xys = self.generate_flow_xys()
        self.flow_query_embed = make_embed_mlp(2)

        self.flow_spatial_pe_map = make_embed_mlp(2)

        self.flow_temporal_pe_map = make_embed_mlp(18)

    def init_weights(self):
        self.transformer.init_weights()
    
    def generate_flow_xys(self,):
        pc_range = 51.2
        tan_30 = math.tan(math.radians(30))
        flow_items = self.flow_num
        p_x = pc_range * tan_30

        f_range_x = [-p_x, p_x]
        f_range_y = [pc_range, pc_range]

        fr_range_x = [p_x, pc_range]
        fr_range_y = [pc_range, 0]

        br_range_x = [pc_range, p_x]
        br_range_y = [0, -pc_range]

        b_range_x = [-p_x, p_x]
        b_range_y = [-pc_range, -pc_range]

        bl_range_x = [-p_x, -pc_range]
        bl_range_y = [-pc_range, 0]

        fl_range_x = [-pc_range, -p_x]
        fl_range_y = [0, pc_range]

        flow_xys_all = []
        for range_x, range_y in zip([f_range_x, fr_range_x, br_range_x, b_range_x, bl_range_x, fl_range_x],
                                    [f_range_y, fr_range_y, br_range_y, b_range_y, bl_range_y, fl_range_y]):
            step_x = (range_x[1]-range_x[0])/flow_items
            step_y = (range_y[1]-range_y[0])/flow_items
            xys = torch.stack([
                torch.arange(range_x[0], range_x[1], step_x)+step_x/2 if step_x != 0 else torch.ones(flow_items)*range_x[0],
                torch.arange(range_y[0], range_y[1], step_y)+step_y/2 if step_y != 0 else torch.ones(flow_items)*range_y[0],
                                ], dim=-1)
            flow_xys_all.append(xys)

        flow_xys_all = torch.stack(flow_xys_all, dim=0)
        return flow_xys_all
    
    def get_forward_point(self, sdc_trajs):
        ego_xys = [one[1] for one in sdc_trajs]
        forward_points = []
        # tan_30 = math.tan(math.radians(30))
        tan_60 = math.tan(math.radians(60))
        half = 51.2 / tan_60
        step = half / self.flow_num * 2
        for ego_xy in ego_xys:
            if sum(ego_xy) == 0:
                forward_points.append(11)
                continue
            if ego_xy[0] >= 0:
                left_right = 1
            else:
                left_right = 0
            tan = abs(ego_xy[1]/ego_xy[0])
            if tan >= tan_60:
                forward_point = (51.2/tan + half)//step
            else:
                if left_right:
                    x = 51.2 * tan_60 / (tan + tan_60)
                    y = x * tan
                    delta = math.sqrt((x-51.2/tan_60)**2 + (y-51.2)**2)
                    forward_point = 22 + delta//step
                else:
                    x = 51.2 * tan_60 / (-tan - tan_60)
                    y = - x * tan
                    delta = math.sqrt((x+51.2/tan_60)**2 + (y-51.2)**2)
                    forward_point = - delta//step
            forward_points.append(int(forward_point))
        return forward_points[0]

    def forward(self, mlvl_feats, img_metas):
        # cam_types = [
        #         'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT',
        #         'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT'
        #     ]
        can_bus = mlvl_feats[0].new_tensor(
            [each['ego_pose'][0] for each in img_metas])  # [:, :]
        can_bus = self.can_bus_mlp(can_bus)[:, None, :]

        # print([one.shape for one in mlvl_feats])

        device = mlvl_feats[0].device
        B, TN, C, _, _ = mlvl_feats[0].shape
        N = 6
        T = TN // N
        flow_spatial_pe = self.flow_spatial_pe_map(self.flow_query_xys.to(device)).unsqueeze(0)  # [1, N, 22, C]
        flow_spatial_pe_local = rearrange(flow_spatial_pe.repeat(B*T, 1, 1, 1), 'b n y c -> (b n) y c')  # [BTN, 22, C]

        # pad spatial PE for 3*patch self-attn
        flow_spatial_pe_3p = flow_spatial_pe_local.transpose(1,2).reshape(B, T, N, C, 1, self.flow_num)
        pe_pad_left = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single:]
        pe_pad_right = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single]
        # [B, T, N, C, 1, 24]
        flow_spatial_pe_3p = torch.cat([pe_pad_left, flow_spatial_pe_3p, pe_pad_right], dim=-1)
        # [BTN, C*1*3, 1*22]
        flow_spatial_pe_3p_unfold = F.unfold(flow_spatial_pe_3p.flatten(0,2), kernel_size=(1, self.flow_pad_sum), stride=1)
        # [BTN, 22, 3, C]
        flow_spatial_pe_3p = rearrange(flow_spatial_pe_3p_unfold, 'b (c l) y -> b y l c', c=C, l=self.flow_pad_sum)

        ego_pose_temporal = torch.from_numpy(
            np.stack([each['ego_pose'] for each in img_metas], axis=0)).to(device).to(flow_spatial_pe.dtype)
        flow_temporal_pe = self.flow_temporal_pe_map(ego_pose_temporal)  # [B, T, C]
        flow_temporal_pe_local = flow_temporal_pe.reshape(B, T, 1, 1, C).repeat(1, 1, N, self.flow_num, 1).flatten(0,2)  # [BTN, 22, C]

        outs = {}
        outputs_wm_temporal = []
        outputs_wm_spatial = []

        mlvl_feats_flow = []
        mlvl_feats_flow_unfold = []
        mlvl_feats_flow_unfold_fuse = []

        H_W_Ps = []
        for idx, mlvl_feat in enumerate(mlvl_feats):
            if idx not in self.flow_ids:
                mlvl_feats_flow.append(mlvl_feat)
                continue
            flow_patch = self.flow_patches[idx]
            B, TN, C, H, W = mlvl_feat.shape
            H_W_Ps.append([H,W,flow_patch])
            mlvl_feat_reshaped = mlvl_feat.reshape(B, T, N, C, H, W)

            # pad two sides of each view
            feat_pad_left = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single*flow_patch:]
            feat_pad_right = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single*flow_patch]
            mlvl_feat_pad_reshaped = torch.cat([feat_pad_left, mlvl_feat_reshaped, feat_pad_right], dim=-1)

            # [BTN, C*H*3P, 1*22]
            mlvl_feat_pad_unfold = F.unfold(mlvl_feat_pad_reshaped.flatten(0,2), kernel_size=(H, self.flow_pad_sum*flow_patch), stride=flow_patch)
            # [BTN, C, H, 3P, 22]
            mlvl_feat_flow_unfold = mlvl_feat_pad_unfold.reshape(B*TN, C, H, self.flow_pad_sum*flow_patch, W//flow_patch)

            if idx in self.flow_attn_ids:
                # interaction inside 3 patches
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, 'b c h p y -> (b h y) p c')
                # [BTN, 22, 3, C] -> [BTN * H * 22, 3P, C]
                flow_spatial_pe_3p_selfattn = flow_spatial_pe_3p.unsqueeze(3).unsqueeze(1).repeat(1,H,1,1,flow_patch,1).flatten(0,2).flatten(1,2)
                # [BTN * H * 22, 3P, C]
                mlvl_feat_flow_unfold = self.flow_attns[self.flow_attn_ids.index(idx)](mlvl_feat_flow_unfold, pos=flow_spatial_pe_3p_selfattn)
                # [BTN, C, H, 3P, 22]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, '(b h y) p c -> b c h p y', b=B*TN, h=H, y=W//flow_patch)

            # [BTN, C, H, 22, 3P] -> [BTN, C, H, 22, P]
            mlvl_feat_flow_unfold = self.flow_mlps[self.flow_ids.index(idx)](mlvl_feat_flow_unfold.transpose(3,4))
            mlvl_feats_flow_unfold.append(mlvl_feat_flow_unfold)
            # [BTN, C, 1, 22, patch]
            mlvl_feats_flow_unfold_fuse.append(mlvl_feat_flow_unfold.mean(dim=2))
        
        # [BTN, C, 22, 8+4+2+1]
        mlvl_feats_flow_unfold_fuse = torch.cat(mlvl_feats_flow_unfold_fuse, dim=-1)
        # [BTN, 22, C] -> [BTN, 1, 22, C]
        flow_spatial_pe_selfattn = flow_spatial_pe_local.unsqueeze(1)
        # [BTN, 8+4+2+1, 22, C]
        mlvl_feats_flow_unfold_fuse = self.patch_selfattn(rearrange(mlvl_feats_flow_unfold_fuse, 'b c y p -> b p y c'), pos=flow_spatial_pe_selfattn)

        # [N, 22, C] -> [BTN, 22, C]
        flow_query = self.flow_query_embed(self.flow_query_xys.to(mlvl_feat.device)).unsqueeze(0).repeat(B*T, 1, 1, 1).flatten(0,1)
        # get local messages from [8+4+2+1] of [BTN, 8+4+2+1, 22, C]
        flow_query = self.patch_crossattn_local(flow_query.unsqueeze(1), mlvl_feats_flow_unfold_fuse).reshape(B*TN, W//flow_patch, C)
        # print('output', flow_query.shape)

        # get temporal messages from [T] of [BTN, 8+4+2+1, 22, C]
        # prompt for temporal world model
        flow_temporal_pe_crossattn = rearrange(flow_temporal_pe_local, '(b t n) y c -> (b n) t y c', b=B, t=T, n=6)
        # [BN, T, 22, C]
        flow_query_temporal = self.patch_crossattn_temporal(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_temporal_pe_crossattn,
            k_pos=flow_temporal_pe_crossattn,
            flatten=False
        )
        # print('output', flow_query_temporal.shape)

        # get spatial messages from [22] of [BTN, 8+4+2+1, 22, C]
        # prompt for spatial world model
        flow_spatial_pe_crossattn = rearrange(flow_spatial_pe_local, '(b t n) y c -> (b n) y t c', b=B, t=T, n=6)
        # [BN, 22, T, C]
        flow_query_spatial = self.patch_crossattn_spatial(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_spatial_pe_crossattn,
            k_pos=flow_spatial_pe_crossattn,
            flatten=False
        )
        # print('output', flow_query_spatial.shape)

        # split left/right spatial flow from ego forward_point
        sdc_trajs = [each['sdc_traj'] for each in img_metas]
        forward_point = self.get_forward_point(sdc_trajs)
        flow_query_spatial = rearrange(flow_query_spatial, '(b n) y t c -> (b t n) y c', b=B, t=T, n=6)
        flow_query_spat = rearrange(flow_query_spatial, '(b t n) y c -> b (n y) t c', b=B, t=T, n=6)
        if forward_point >= 0:
            flow_query_spat_l = torch.cat([flow_query_spat[:, forward_point+3*self.flow_num:], flow_query_spat[:, :forward_point]], dim=1).flip(dims=[1])
            flow_query_spat_r = flow_query_spat[:, forward_point:forward_point+3*self.flow_num]
        else:
            flow_query_spat_r = torch.cat([flow_query_spat[:, forward_point:], flow_query_spat[:, :forward_point+3*self.flow_num:]], dim=1)
            flow_query_spat_l = flow_query_spat[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
        # [B, 2, 66, T, C]
        flow_query_spat = torch.stack([flow_query_spat_l, flow_query_spat_r], dim=1)
        flow_query_spat = rearrange(flow_query_spat, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

        for idx, (mlvl_feat_flow_unfold, (H, W, flow_patch)) in enumerate(zip(mlvl_feats_flow_unfold, H_W_Ps)):
            if idx in self.flow_wm_ids:
                layer_idx = self.flow_wm_ids.index(idx)

                # temporal world model
                # [BTN, C, H, W/patch, patch]
                output_wm, mlvl_feat_flow_unfold_fut = self.world_models_temporal[layer_idx](
                    rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> (b n) t y c', b=B, t=T, n=6).flip(dims=[1]),
                    flow_query_temporal.flip(dims=[1]))
                outputs_wm_temporal.append(output_wm)
                
                # enhance temporal messages with temporal prediction
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        rearrange(mlvl_feat_flow_unfold_fut.flip(dims=[1]), '(b n) t y c -> (b t n) y c', b=B, n=6, y=W//flow_patch).unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, C, H, patch, W/patch]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs[layer_idx](mlvl_feat_flow_fold)

                mlvl_feat_flow_unfold = F.unfold(mlvl_feat_flow_fold, kernel_size=(H, flow_patch), stride=flow_patch).reshape(B*TN, C, H, flow_patch, W//flow_patch).transpose(3,4)  # [btn, chp, y]
                
                # split left/right spatial flow from ego forward_point
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> b (n y) t c', b=B, t=T, n=6)
                # mlvl_feat_flow_unfold_spatial_l = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:], mlvl_feat_flow_unfold_spatial[:, :forward_point]], dim=1).flip(dims=[1])
                # mlvl_feat_flow_unfold_spatial_r = mlvl_feat_flow_unfold_spatial[:, forward_point:forward_point+3*self.flow_num]
                if forward_point >= 0:
                    mlvl_feat_flow_unfold_spatial_l = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:], mlvl_feat_flow_unfold_spatial[:, :forward_point]], dim=1).flip(dims=[1])
                    mlvl_feat_flow_unfold_spatial_r = mlvl_feat_flow_unfold_spatial[:, forward_point:forward_point+3*self.flow_num]
                else:
                    mlvl_feat_flow_unfold_spatial_r = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point:], mlvl_feat_flow_unfold_spatial[:, :forward_point+3*self.flow_num]], dim=1)
                    mlvl_feat_flow_unfold_spatial_l = mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
                # [B, 2, 66, T, C]
                mlvl_feat_flow_unfold_spatial = torch.stack([mlvl_feat_flow_unfold_spatial_l, mlvl_feat_flow_unfold_spatial_r], dim=1)
                # split with spatial patch (len 6)
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold_spatial, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)
                
                # spatial world model
                # [BTN, C, H, W/patch, patch]
                output_wm_spatial, mlvl_feat_flow_unfold_fut_spatial = self.world_models_spatial[layer_idx](
                    mlvl_feat_flow_unfold_spatial,
                    flow_query_spat)
                outputs_wm_spatial.append(output_wm_spatial)

                # enhance spatial messages with spatial prediction
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, '(b n p) y t c -> b n (y p) t c', b=B, n=2, p=self.spat_size)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,1], mlvl_feat_flow_unfold_fut_spatial[:,0].flip(dims=[1])], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,-forward_point:], mlvl_feat_flow_unfold_fut_spatial[:,:-forward_point]], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, 'b (n y) t c -> (b t n) y c', n=6, y=self.flow_num)
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        mlvl_feat_flow_unfold_fut_spatial.unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, C, H, patch*3, W/patch]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs_spatial[layer_idx](mlvl_feat_flow_fold)
            else:
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
            mlvl_feats_flow.append(mlvl_feat_flow_fold.reshape(B, TN, C, H, W))


        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]
        #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()

        # query denoising
        B = mlvl_feats[0].shape[0]
        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)

        # can_bus = query_feat.new_tensor(
        #     [each['ego_pose'][0] for each in img_metas])  # [:, :]
        # can_bus = self.can_bus_mlp(can_bus)[:, None, :]
        query_feat = query_feat + can_bus

        cls_scores, bbox_preds = self.transformer(
            query_bbox,
            query_feat,
            mlvl_feats_flow,  # mlvl_feats
            attn_mask=attn_mask,
            img_metas=img_metas,
        )

        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]

        bbox_preds = torch.cat([
            bbox_preds[..., 0:2],
            bbox_preds[..., 3:5],
            bbox_preds[..., 2:3],
            bbox_preds[..., 5:10],
        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]

        if mask_dict is not None and mask_dict['pad_size'] > 0:  # if using query denoising
            output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
            output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
            output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
            output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]
            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
            outs.update({
                'all_cls_scores': output_cls_scores,
                'all_bbox_preds': output_bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None, 
                'dn_mask_dict': mask_dict,
            })
        else:
            outs.update({
                'all_cls_scores': cls_scores,
                'all_bbox_preds': bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None, 
            })
        outs['wm_temporal'] = outputs_wm_temporal
        outs['wm_spatial'] = outputs_wm_spatial

        return outs

    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
        """Build the decoder queries, optionally prepending denoising (DN) queries.

        Mostly borrowed from:
          - https://github.com/IDEA-Research/DN-DETR/blob/main/models/DN_DAB_DETR/dn_components.py
          - https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petrv2_dnhead.py

        Args:
            batch_size: Number of samples; the query set is repeated once per sample.
            init_query_bbox: Initial query boxes; its device is used for every
                tensor created here and its last dim is the box dim of DN queries.
            label_enc: Embedding module. ``label_enc.weight[self.num_classes]`` is
                the feature of every matching query and ``label_enc(labels)`` the
                feature of every DN query.
            img_metas: Per-sample dicts. Only read in the DN branch, where each
                must provide ``gt_bboxes_3d`` (with ``gravity_center`` and
                ``tensor``) and ``gt_labels_3d``.

        Returns:
            tuple: ``(input_query_bbox, input_query_feat, attn_mask, mask_dict)``.
            ``attn_mask`` and ``mask_dict`` are ``None`` unless the module is in
            training mode with ``self.dn_enabled`` set.
        """
        device = init_query_bbox.device

        # Matching queries: the extra ("no object") label embedding plus a
        # trailing 0 that marks them as non-denoising queries.
        indicator0 = torch.zeros([self.num_query, 1], device=device)
        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)

        if not (self.training and self.dn_enabled):
            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
            return input_query_bbox, input_query_feat, None, None

        # Move ground truth to the queries' device instead of a hard-coded
        # ``.cuda()`` so the method follows wherever the module lives.
        targets = [{
            'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center,
                                 m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).to(device),
            'labels': m['gt_labels_3d'].to(device).long()
        } for m in img_metas]

        known = [torch.ones_like(t['labels'], device=device) for t in targets]
        # Number of GT boxes per sample, as plain ints (no per-element tensor sum).
        known_num = [t['labels'].numel() for t in targets]

        # can be modified to selectively denoise some label or boxes; also known label prediction
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets]).clone()
        bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
        batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        # replicate every GT once per denoising group
        known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
        known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
        known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
        known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
        known_labels_expand = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        # noise on the box: shift the center by up to half the box size per axis
        if self.dn_bbox_noise_scale > 0:
            wlh = known_bbox_expand[..., 3:6].clone()
            rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
            known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale

        known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
        known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)

        # noise on the label: replace a random subset with random classes
        if self.dn_label_noise_scale > 0:
            p = torch.rand_like(known_labels_expand.float())
            chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
            new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
            known_labels_expand.scatter_(0, chosen_indice, new_label)

        known_feat_expand = label_enc(known_labels_expand)
        indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
        known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)

        # construct final query: [dn groups (padded) | matching queries]
        dn_single_pad = int(max(known_num))
        dn_pad_size = int(dn_single_pad * self.dn_group_num)
        dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
        dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
        input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
        input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)

        # Slot of every GT inside its sample's DN section, e.g. [0,1, 0,1,2],
        # offset by ``dn_single_pad`` for each further group.
        map_known_indice = torch.cat([torch.arange(num, device=device) for num in known_num])
        map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()

        if len(known_bid):
            input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
            input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand

        total_size = dn_pad_size + self.num_query
        # True entries are the ones set to be blocked below.
        attn_mask = torch.zeros([total_size, total_size], dtype=torch.bool, device=device)

        # match query cannot see the reconstruct
        attn_mask[dn_pad_size:, :dn_pad_size] = True
        # A DN group cannot see any other DN group. For the first/last group
        # the left/right slice is simply empty, so no special-casing is needed.
        for i in range(self.dn_group_num):
            group_start = dn_single_pad * i
            group_end = dn_single_pad * (i + 1)
            attn_mask[group_start:group_end, group_end:dn_pad_size] = True
            attn_mask[group_start:group_end, :group_start] = True

        mask_dict = {
            'known_indice': torch.as_tensor(known_indice).long(),
            'batch_idx': torch.as_tensor(batch_idx).long(),
            'map_known_indice': torch.as_tensor(map_known_indice).long(),
            'known_lbs_bboxes': (known_labels, known_bboxs),
            'pad_size': dn_pad_size
        }

        return input_query_bbox, input_query_feat, attn_mask, mask_dict

    def prepare_for_dn_loss(self, mask_dict):
        """Pick the denoising predictions that belong to real ground-truth slots.

        Args:
            mask_dict: Dict produced by ``prepare_for_dn_input`` and extended in
                the forward pass with ``'output_known_lbs_bboxes'``.

        Returns:
            tuple: ``(known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt)``
            where the predictions are indexed as [layer, target, channel] and
            ``num_tgt`` is the number of entries in ``known_indice``.
        """
        dn_cls_scores, dn_bbox_preds = mask_dict['output_known_lbs_bboxes']
        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
        slot_idx = mask_dict['map_known_indice'].long()
        known_indice = mask_dict['known_indice'].long()
        # sample index of every (group-replicated) ground-truth entry
        sample_idx = mask_dict['batch_idx'].long()[known_indice]
        num_tgt = known_indice.numel()

        def _gather(preds):
            # Bring (sample, query) to the front, select the known slots, then
            # put the leading (per-layer) axis back in first position.
            per_slot = preds.permute(1, 2, 0, 3)[(sample_idx, slot_idx)]
            return per_slot.permute(1, 0, 2)

        if len(dn_cls_scores) > 0:
            dn_cls_scores = _gather(dn_cls_scores)
            dn_bbox_preds = _gather(dn_bbox_preds)

        return known_labels, known_bboxs, dn_cls_scores, dn_bbox_preds, num_tgt

    def dn_loss_single(self,
                       cls_scores,
                       bbox_preds,
                       known_bboxs,
                       known_labels,
                       num_total_pos=None):
        """Compute the denoising classification and box losses for one set of predictions.

        Args:
            cls_scores: Class scores of the denoising queries.
            bbox_preds: Box predictions of the denoising queries.
            known_bboxs: Ground-truth boxes the denoising queries were built from.
            known_labels: Ground-truth labels of those boxes.
            num_total_pos: Number of denoising targets, used as normaliser.

        Returns:
            tuple: ``(loss_cls, loss_bbox)``, both scaled by ``self.dn_weight``
            with NaN/inf entries replaced via ``torch.nan_to_num``.
        """
        # Average number of gt boxes across all gpus, never below 1.
        avg_factor = torch.clamp(
            reduce_mean(cls_scores.new_tensor([num_total_pos])), min=1.0).item()

        # classification loss: every denoising query carries unit weight
        flat_cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        reg_weights = torch.ones_like(bbox_preds)
        cls_weights = torch.ones_like(known_labels)
        loss_cls = self.loss_cls(
            flat_cls_scores,
            known_labels.long(),
            cls_weights,
            avg_factor=avg_factor
        )

        # regression L1 loss, restricted to rows whose target is fully finite
        flat_bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        reg_targets = normalize_bbox(known_bboxs)
        finite_rows = torch.isfinite(reg_targets).all(dim=-1)
        reg_weights = reg_weights * self.code_weights
        loss_bbox = self.loss_bbox(
            flat_bbox_preds[finite_rows, :10],
            reg_targets[finite_rows, :10],
            reg_weights[finite_rows, :10],
            avg_factor=avg_factor
        )

        scaled_cls = self.dn_weight * torch.nan_to_num(loss_cls)
        scaled_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)
        return scaled_cls, scaled_bbox

    @force_fp32(apply_to=('preds_dicts'))
    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
        """Add the denoising losses of every decoder layer to ``loss_dict``.

        Args:
            loss_dict: Dict of losses collected so far; updated in place.
            preds_dicts: Head outputs; only ``preds_dicts['dn_mask_dict']`` is read.
            num_dec_layers: Number of decoder layers to supervise.

        Returns:
            dict: The same ``loss_dict``. The last layer's losses are stored as
            ``loss_cls_dn`` / ``loss_bbox_dn``, earlier layers as
            ``d{i}.loss_cls_dn`` / ``d{i}.loss_bbox_dn``.
        """
        (known_labels, known_bboxs, cls_scores, bbox_preds,
         num_tgt) = self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])

        # Every decoder layer is supervised with the same targets.
        dn_losses_cls, dn_losses_bbox = multi_apply(
            self.dn_loss_single, cls_scores, bbox_preds,
            [known_bboxs] * num_dec_layers,
            [known_labels] * num_dec_layers,
            [num_tgt] * num_dec_layers)

        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]

        earlier_layers = zip(dn_losses_cls[:-1], dn_losses_bbox[:-1])
        for layer_idx, (layer_loss_cls, layer_loss_bbox) in enumerate(earlier_layers):
            loss_dict[f'd{layer_idx}.loss_cls_dn'] = layer_loss_cls
            loss_dict[f'd{layer_idx}.loss_bbox_dn'] = layer_loss_bbox

        return loss_dict

    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_labels,
                           gt_bboxes,
                           gt_bboxes_ignore=None):
        num_bboxes = bbox_pred.size(0)

        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)
        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0
        
        # DETR
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Build classification and regression targets for every sample of a batch.

        Args:
            cls_scores_list: Per-sample class scores.
            bbox_preds_list: Per-sample box predictions.
            gt_bboxes_list: Per-sample ground-truth boxes.
            gt_labels_list: Per-sample ground-truth labels.
            gt_bboxes_ignore_list: Must be ``None``.

        Returns:
            tuple: ``(labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_pos, num_total_neg)``; the last two are
            the positive / negative index counts summed over all samples.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        # one (None) ignore entry per sample so multi_apply can zip over it
        ignore_per_sample = [gt_bboxes_ignore_list] * len(cls_scores_list)

        per_sample_targets = multi_apply(
            self._get_target_single, cls_scores_list, bbox_preds_list,
            gt_labels_list, gt_bboxes_list, ignore_per_sample)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = per_sample_targets

        num_total_pos = sum(pos_inds.numel() for pos_inds in pos_inds_list)
        num_total_neg = sum(neg_inds.numel() for neg_inds in neg_inds_list)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    gt_bboxes_ignore_list=None):
        """Compute the classification and box losses for one set of predictions.

        Args:
            cls_scores: Class scores, with the sample axis first.
            bbox_preds: Box predictions, with the sample axis first.
            gt_bboxes_list: Per-sample ground-truth boxes.
            gt_labels_list: Per-sample ground-truth labels.
            gt_bboxes_ignore_list: Forwarded to ``get_targets``.

        Returns:
            tuple: ``(loss_cls, loss_bbox)`` with NaN/inf entries replaced via
            ``torch.nan_to_num``.
        """
        # split the batch into per-sample views for target assignment
        per_sample_scores = list(cls_scores.unbind(0))
        per_sample_boxes = list(bbox_preds.unbind(0))
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = self.get_targets(
            per_sample_scores, per_sample_boxes,
            gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)

        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        flat_cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                flat_cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            flat_cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Average number of gt boxes across all gpus (at least 1), used to
        # normalise the regression loss.
        reg_avg_factor = torch.clamp(
            reduce_mean(loss_cls.new_tensor([num_total_pos])), min=1).item()

        # regression L1 loss, restricted to rows whose target is fully finite
        flat_bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        reg_targets = normalize_bbox(bbox_targets)
        finite_rows = torch.isfinite(reg_targets).all(dim=-1)
        reg_weights = bbox_weights * self.code_weights

        loss_bbox = self.loss_bbox(
            flat_bbox_preds[finite_rows, :10],
            reg_targets[finite_rows, :10],
            reg_weights[finite_rows, :10],
            avg_factor=reg_avg_factor
        )

        return torch.nan_to_num(loss_cls), torch.nan_to_num(loss_bbox)

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self,
             gt_bboxes_list,
             gt_labels_list,
             preds_dicts,
             gt_bboxes_ignore=None):
        """Compute all training losses of the head.

        Args:
            gt_bboxes_list: Per-sample ground-truth boxes exposing
                ``gravity_center`` and ``tensor``.
            gt_labels_list: Per-sample ground-truth labels.
            preds_dicts: Head outputs. Reads ``all_cls_scores``,
                ``all_bbox_preds``, ``enc_cls_scores``, ``enc_bbox_preds``,
                ``wm_temporal``, ``wm_spatial`` and, if present,
                ``dn_mask_dict``.
            gt_bboxes_ignore: Must be ``None``.

        Returns:
            dict: Detection losses of the last decoder layer (``loss_cls``,
            ``loss_bbox``), of earlier layers (``d{i}.*``), optional encoder and
            denoising losses, and the world-model losses per entry
            (``l{i}.loss_wm_t`` / ``l{i}.loss_wm_s``) plus their sums
            (``loss_wm_t`` / ``loss_wm_s``).
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'for gt_bboxes_ignore setting to None.'

        all_cls_scores = preds_dicts['all_cls_scores']
        all_bbox_preds = preds_dicts['all_bbox_preds']
        enc_cls_scores = preds_dicts['enc_cls_scores']
        enc_bbox_preds = preds_dicts['enc_bbox_preds']

        num_dec_layers = len(all_cls_scores)
        device = gt_labels_list[0].device
        # Re-express every GT box with its gravity center in the first 3 channels.
        gt_bboxes_list = [torch.cat(
            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
            dim=1).to(device) for gt_bboxes in gt_bboxes_list]

        # Every decoder layer is supervised with the same ground truth.
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]

        losses_cls, losses_bbox = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss of proposal generated from encode feature map
        if enc_cls_scores is not None:
            # One all-zero label tensor per sample (not per decoder layer).
            binary_labels_list = [
                torch.zeros_like(gt_labels) for gt_labels in gt_labels_list
            ]
            enc_loss_cls, enc_losses_bbox = \
                self.loss_single(enc_cls_scores, enc_bbox_preds,
                                 gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox

        if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:
            loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)

        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]

        # loss from other decoder layers
        for layer_idx, (loss_cls_i, loss_bbox_i) in enumerate(
                zip(losses_cls[:-1], losses_bbox[:-1])):
            loss_dict[f'd{layer_idx}.loss_cls'] = loss_cls_i
            loss_dict[f'd{layer_idx}.loss_bbox'] = loss_bbox_i

        def _add_wm_losses(wm_outputs, criterion, tag):
            # Store one prior/posterior loss per world-model output plus their sum.
            per_entry = []
            for idx, wm_output in enumerate(wm_outputs):
                value = torch.nan_to_num(
                    criterion(wm_output['prior'], wm_output['posterior']))
                loss_dict[f'l{idx}.loss_wm_{tag}'] = value
                per_entry.append(value)
            # ``sum([])`` would be the Python int 0, which is not a tensor loss;
            # fall back to a zero tensor when there is nothing to sum.
            # NOTE(review): the per-entry keys and the summed key all contain
            # 'loss'; if the training loop adds up every such key these terms
            # are counted twice — confirm this is intended.
            loss_dict[f'loss_wm_{tag}'] = (
                sum(per_entry) if per_entry else torch.zeros_like(losses_cls[-1]))

        _add_wm_losses(preds_dicts['wm_temporal'], self.loss_wm_temporal, 't')
        _add_wm_losses(preds_dicts['wm_spatial'], self.loss_wm_spatial, 's')
        return loss_dict

    @force_fp32(apply_to=('preds_dicts'))
    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
        """Decode head outputs into per-sample boxes, scores and labels.

        Args:
            preds_dicts: Head outputs, passed to ``self.bbox_coder.decode``.
            img_metas: Unused here.
            rescale: Unused here.

        Returns:
            list: One ``[bboxes, scores, labels]`` entry per decoded sample,
            where ``bboxes`` is a ``LiDARInstance3DBoxes`` with box dim 9.
        """
        decoded = self.bbox_coder.decode(preds_dicts)
        ret_list = []
        for preds in decoded:
            bboxes = preds['bboxes']
            # Lower z by half of column 5 — presumably gravity-center z to
            # bottom-center z with column 5 being the height; confirm against
            # the bbox coder's output layout.
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5

            if VERSION.name == 'v0.17.1':
                # Swap columns 3 and 4. Indexing with a list on the right-hand
                # side yields a copy, so the in-place write cannot alias it.
                bboxes[:, [3, 4]] = bboxes[:, [4, 3]]
                bboxes[:, 6] = -bboxes[:, 6] - math.pi / 2

            bboxes = LiDARInstance3DBoxes(bboxes, 9)
            ret_list.append([bboxes, preds['scores'], preds['labels']])
        return ret_list
