import torch
import torch.nn as nn
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.runner import force_fp32
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.models import HEADS, build_loss
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
from mmdet.models.utils import NormedLinear, build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid

from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
from projects.mmdet3d_plugin.models.utils.misc import (
    MLN,
    SELayer_Linear,
    memory_refresh,
    topk_gather,
    transform_reference_points,
)
from projects.mmdet3d_plugin.models.utils.positional_encoding import (
    nerf_positional_encoding,
    pos2posemb1d,
    pos2posemb3d,
)


@HEADS.register_module()
class StreamPETRHeadPruning(AnchorFreeHead):
    _version = 2

    def __init__(
        self,
        num_classes,
        in_channels=256,
        stride=16,
        embed_dims=256,
        num_query=100,
        num_reg_fcs=2,
        memory_len=1024,
        topk_proposals=256,
        num_propagated=256,
        with_dn=True,
        with_ego_pos=True,
        match_with_velo=True,
        match_costs=None,
        transformer=None,
        sync_cls_avg_factor=False,
        code_weights=None,
        bbox_coder=None,
        loss_cls=dict(
            type="CrossEntropyLoss",
            bg_cls_weight=0.1,
            use_sigmoid=False,
            loss_weight=1.0,
            class_weight=1.0,
        ),
        loss_bbox=dict(type="L1Loss", loss_weight=5.0),
        train_cfg=dict(
            assigner=dict(
                type="HungarianAssigner3D",
                cls_cost=dict(type="ClassificationCost", weight=1.0),
                reg_cost=dict(type="BBoxL1Cost", weight=5.0),
            ),
        ),
        test_cfg=dict(max_per_img=100),
        depth_step=0.8,
        depth_num=64,
        LID=False,
        depth_start=1,
        position_range=[-65, -65, -8.0, 65, 65, 8.0],
        scalar=5,
        noise_scale=0.4,
        noise_trans=0.0,
        dn_weight=1.0,
        split=0.5,
        init_cfg=None,
        normedlinear=False,
        **kwargs,
    ):
        """Configure the head, build its losses/transformer/coder and layers.

        Args:
            num_classes (int): Number of foreground classes. The class-weight
                vector built here has ``num_classes + 1`` entries, the last
                one being the background weight.
            in_channels (int): Channel width of the incoming image tokens.
            stride (int): Stored on the instance; not used in this method.
            embed_dims (int): Width of the query/memory embeddings.
            num_query (int): Number of learnable reference points.
            num_reg_fcs (int): Number of hidden FC stages in the
                classification and regression branches.
            memory_len (int): Number of entries kept in the memory bank.
            topk_proposals (int): Upper bound on the proposals pushed into the
                memory bank after each frame.
            num_propagated (int): Number of memory entries appended to the
                current queries.
            with_dn (bool): Enable denoising queries while training.
            with_ego_pos (bool): Create and use the ego-pose MLN modules.
            match_with_velo (bool): Forwarded to the assigner.
            match_costs (list[float] | None): Per-dimension matching costs
                handed to the assigner; defaults to ``code_weights``.
            transformer (dict): Config passed to ``build_transformer``. It is
                also queried with ``.get("act_cfg", ...)``, so ``None`` (the
                signature default) is not usable in practice.
            sync_cls_avg_factor (bool): Stored on the instance.
            code_weights (list[float] | None): Per-dimension regression
                weights, truncated to ``code_size`` entries.
            bbox_coder (dict): Config passed to ``build_bbox_coder``; the
                coder's ``pc_range`` is read afterwards.
            loss_cls (dict): Classification loss config. It is copied before
                being edited, so neither the default nor the caller's dict is
                modified.
            loss_bbox (dict): Regression loss config.
            train_cfg (dict | None): When truthy it must contain ``assigner``.
            test_cfg (dict): Stored on the instance.
            depth_step (float): Stored on the instance.
            depth_num (int): Number of depth bins (``position_dim`` is
                ``depth_num * 3``).
            LID (bool): If True the depth bins grow with
                ``index * (index + 1)``; otherwise they are uniform.
            depth_start (float): Depth of the first bin.
            position_range (list[float]): Six values; index 3 bounds the depth
                bins and all six normalise the 3D position embedding.
            scalar (int): Number of denoising groups.
            noise_scale (float): Scale of the denoising centre noise.
            noise_trans (float): Additive term of the denoising noise range.
            dn_weight (float): Stored on the instance.
            split (float): Noise-norm threshold above which a denoising query
                is relabelled as background.
            init_cfg (dict | None): Forwarded to ``AnchorFreeHead.__init__``.
            normedlinear (bool): Use ``NormedLinear`` as the final
                classification layer.
            **kwargs: Only ``code_size`` (default 10) is read.
        """
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since it brings inconvenience when the initialization of
        # `AnchorFreeHead` is called.
        self.code_size = kwargs.get("code_size", 10)
        if code_weights is not None:
            self.code_weights = code_weights
        else:
            self.code_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]

        self.code_weights = self.code_weights[: self.code_size]

        if match_costs is not None:
            self.match_costs = match_costs
        else:
            self.match_costs = self.code_weights

        # Work on a private copy: the block below rewrites `class_weight` and
        # removes `bg_cls_weight`. Editing the argument in place would corrupt
        # the shared default dict (a second default-constructed head would then
        # fail the float assertion) as well as the caller's config.
        loss_cls = loss_cls.copy()

        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        class_weight = loss_cls.get("class_weight", None)
        if class_weight is not None and (self.__class__ is StreamPETRHeadPruning):
            assert isinstance(class_weight, float), (
                "Expected "
                "class_weight to have type float. Found "
                f"{type(class_weight)}."
            )
            # NOTE following the official DETR repo, bg_cls_weight means
            # relative classification weight of the no-object class.
            bg_cls_weight = loss_cls.get("bg_cls_weight", class_weight)
            assert isinstance(bg_cls_weight, float), (
                "Expected "
                "bg_cls_weight to have type float. Found "
                f"{type(bg_cls_weight)}."
            )
            class_weight = torch.ones(num_classes + 1) * class_weight
            # the background class is the last index
            class_weight[num_classes] = bg_cls_weight
            loss_cls.update({"class_weight": class_weight})
            loss_cls.pop("bg_cls_weight", None)
            self.bg_cls_weight = bg_cls_weight

        if train_cfg:
            assert "assigner" in train_cfg, (
                "assigner should be provided " "when train_cfg is set."
            )
            assigner = train_cfg["assigner"]

            self.assigner = build_assigner(assigner)
            # DETR sampling=False, so use PseudoSampler
            sampler_cfg = dict(type="PseudoSampler")
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.num_query = num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.memory_len = memory_len
        self.topk_proposals = topk_proposals
        self.num_propagated = num_propagated
        self.with_dn = with_dn
        self.with_ego_pos = with_ego_pos
        self.match_with_velo = match_with_velo
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = embed_dims
        self.depth_step = depth_step
        self.depth_num = depth_num
        self.position_dim = depth_num * 3
        self.LID = LID
        self.depth_start = depth_start
        self.stride = stride

        # denoising (DN) hyper-parameters
        self.scalar = scalar
        self.bbox_noise_scale = noise_scale
        self.bbox_noise_trans = noise_trans
        self.dn_weight = dn_weight
        self.split = split

        self.act_cfg = transformer.get("act_cfg", dict(type="ReLU", inplace=True))
        self.num_pred = 6
        self.normedlinear = normedlinear
        super(StreamPETRHeadPruning, self).__init__(
            num_classes, in_channels, init_cfg=init_cfg
        )

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self.transformer = build_transformer(transformer)

        # Constants are wrapped as frozen Parameters so they are registered on
        # the module without receiving gradients.
        self.code_weights = nn.Parameter(
            torch.tensor(self.code_weights), requires_grad=False
        )

        self.match_costs = nn.Parameter(
            torch.tensor(self.match_costs), requires_grad=False
        )

        self.bbox_coder = build_bbox_coder(bbox_coder)

        self.pc_range = nn.Parameter(
            torch.tensor(self.bbox_coder.pc_range), requires_grad=False
        )

        self.position_range = nn.Parameter(
            torch.tensor(position_range), requires_grad=False
        )

        # Depth bins between `depth_start` and `position_range[3]`.
        index = torch.arange(start=0, end=self.depth_num, step=1).float()
        depth_span = self.position_range[3] - self.depth_start
        if self.LID:
            bin_size = depth_span / (self.depth_num * (1 + self.depth_num))
            coords_d = self.depth_start + bin_size * index * (index + 1)
        else:
            bin_size = depth_span / self.depth_num
            coords_d = self.depth_start + bin_size * index

        self.coords_d = nn.Parameter(coords_d, requires_grad=False)

        self._init_layers()
        self.reset_memory()

    def _init_layers(self):
        """Create every learnable sub-module of the head.

        One classification head and one regression head are built and the
        same module object is registered ``num_pred`` times in each
        ``ModuleList``, so all decoder levels share their weights.
        """
        dims = self.embed_dims

        # classification head: (Linear -> LayerNorm -> ReLU) x num_reg_fcs -> out
        cls_layers = []
        for _ in range(self.num_reg_fcs):
            cls_layers.extend(
                [Linear(dims, dims), nn.LayerNorm(dims), nn.ReLU(inplace=True)]
            )
        final_cls_layer = NormedLinear if self.normedlinear else Linear
        cls_layers.append(final_cls_layer(dims, self.cls_out_channels))
        shared_cls_head = nn.Sequential(*cls_layers)

        # regression head: (Linear -> ReLU) x num_reg_fcs -> code_size
        reg_layers = []
        for _ in range(self.num_reg_fcs):
            reg_layers.extend([Linear(dims, dims), nn.ReLU()])
        reg_layers.append(Linear(dims, self.code_size))
        shared_reg_head = nn.Sequential(*reg_layers)

        self.cls_branches = nn.ModuleList([shared_cls_head] * self.num_pred)
        self.reg_branches = nn.ModuleList([shared_reg_head] * self.num_pred)

        # maps the flattened per-depth 3D coordinates to an embedding
        self.position_encoder = nn.Sequential(
            nn.Linear(self.position_dim, dims * 4),
            nn.ReLU(),
            nn.Linear(dims * 4, dims),
        )

        # projects the image tokens to the embedding width
        self.memory_embed = nn.Sequential(
            nn.Linear(self.in_channels, dims),
            nn.ReLU(),
            nn.Linear(dims, dims),
        )

        # can be replaced with MLN
        self.featurized_pe = SELayer_Linear(dims)

        # One 3-vector Parameter per query; values are filled in `init_weights`.
        self.deprecated_reference_points = nn.ParameterList([])
        self.reference_points = nn.ParameterList(
            [nn.Parameter(torch.empty(3)) for _ in range(self.num_query)]
        )
        if self.num_propagated > 0:
            self.pseudo_reference_points = nn.ParameterList(
                [nn.Parameter(torch.empty(3)) for _ in range(self.num_propagated)]
            )

        self.query_embedding = nn.Sequential(
            nn.Linear(dims * 3 // 2, dims),
            nn.ReLU(),
            nn.Linear(dims, dims),
        )

        self.spatial_alignment = MLN(8)

        self.time_embedding = nn.Sequential(
            nn.Linear(dims, dims), nn.LayerNorm(dims)
        )

        # ego-pose conditioning for positional embeddings and memory features
        if self.with_ego_pos:
            self.ego_pose_pe = MLN(180)
            self.ego_pose_memory = MLN(180)

    def init_weights(self):
        """Initialise reference points, the transformer and the class biases.

        Reference points are drawn uniformly from [0, 1). The pseudo reference
        points are drawn the same way and then frozen. When the classification
        loss uses sigmoid, the last layer's bias of every classification
        branch is set from a 0.01 prior probability.
        """
        # The initialization for transformer is important
        for ref_point in self.reference_points:
            nn.init.uniform_(ref_point, 0, 1)

        if self.num_propagated > 0:
            for pseudo_point in self.pseudo_reference_points:
                nn.init.uniform_(pseudo_point, 0, 1)
                pseudo_point.requires_grad = False

        self.transformer.init_weights()

        if not self.loss_cls.use_sigmoid:
            return
        bias_init = bias_init_with_prob(0.01)
        for branch in self.cls_branches:
            nn.init.constant_(branch[-1].bias, bias_init)

    def reset_memory(self):
        self.memory_embedding = None
        self.memory_reference_point = None
        self.memory_timestamp = None
        self.memory_egopose = None
        self.memory_velo = None

    def pre_update_memory(self, data):
        """Bring the memory bank into the current frame before decoding.

        If no memory exists yet, zero-filled banks are allocated. Otherwise
        timestamps, ego poses and reference points are updated with the
        current frame's ``timestamp`` / ``ego_pose_inv`` and every bank is
        trimmed to ``memory_len`` and passed through ``memory_refresh``.

        Args:
            data (dict): Reads ``prev_exists`` and, when memory already
                exists, ``timestamp`` and ``ego_pose_inv``.
                NOTE(review): ``prev_exists`` is presumably 1 where the
                previous frame belongs to the same scene, else 0 - confirm.
        """
        prev_exists = data["prev_exists"]
        batch = prev_exists.size(0)

        # refresh the memory when the scene changes
        if self.memory_embedding is None:
            length = self.memory_len
            self.memory_embedding = prev_exists.new_zeros(
                batch, length, self.embed_dims
            )
            self.memory_reference_point = prev_exists.new_zeros(batch, length, 3)
            self.memory_timestamp = prev_exists.new_zeros(batch, length, 1)
            self.memory_egopose = prev_exists.new_zeros(batch, length, 4, 4)
            self.memory_velo = prev_exists.new_zeros(batch, length, 2)
        else:
            ego_pose_inv = data["ego_pose_inv"]
            self.memory_timestamp += data["timestamp"].unsqueeze(-1).unsqueeze(-1)
            self.memory_egopose = ego_pose_inv.unsqueeze(1) @ self.memory_egopose
            self.memory_reference_point = transform_reference_points(
                self.memory_reference_point, ego_pose_inv, reverse=False
            )
            # trim each bank to `memory_len` and refresh it, in a fixed order
            for bank_name in (
                "memory_timestamp",
                "memory_reference_point",
                "memory_embedding",
                "memory_egopose",
                "memory_velo",
            ):
                trimmed = getattr(self, bank_name)[:, : self.memory_len]
                setattr(self, bank_name, memory_refresh(trimmed, prev_exists))

        # for the first frame, padding pseudo_reference_points (non-learnable)
        if self.num_propagated > 0:
            n_prop = self.num_propagated
            lower = self.pc_range[0:3]
            extent = self.pc_range[3:6] - self.pc_range[0:3]
            pseudo_points = torch.stack(list(self.pseudo_reference_points[:n_prop]))
            pseudo_points = pseudo_points * extent + lower

            # (1 - prev_exists) gates the padding to samples without history
            no_history = 1 - prev_exists
            self.memory_reference_point[:, :n_prop] = (
                self.memory_reference_point[:, :n_prop]
                + no_history.view(batch, 1, 1) * pseudo_points
            )
            self.memory_egopose[:, :n_prop] = self.memory_egopose[
                :, :n_prop
            ] + no_history.view(batch, 1, 1, 1) * torch.eye(
                4, device=prev_exists.device
            )

    def post_update_memory(
        self, data, rec_ego_pose, all_cls_scores, all_bbox_preds, outs_dec, mask_dict
    ):
        """Push the best predictions of the last decoder level into memory.

        The top ``topk_proposals`` queries (ranked by their highest sigmoid
        class score) are prepended to each memory bank. Afterwards the
        reference points, timestamps and ego poses are updated with the
        current frame's ``ego_pose`` / ``timestamp`` from ``data``.

        Args:
            data (dict): Reads ``ego_pose`` and ``timestamp``.
            rec_ego_pose (Tensor): Per-query poses, gathered with the same
                top-k indexes as the other quantities.
            all_cls_scores (Tensor): Class logits; the last level is used.
            all_bbox_preds (Tensor): Box predictions; the last level is used
                (first 3 channels as centre, last 2 as velocity).
            outs_dec (Tensor): Decoder outputs; the last level is used.
            mask_dict (dict | None): Denoising info; during training its
                ``pad_size`` leading queries are excluded.
        """
        # Skip the leading denoising queries (if any) when selecting proposals.
        has_dn = self.training and mask_dict and mask_dict["pad_size"] > 0
        first = mask_dict["pad_size"] if has_dn else 0

        last_boxes = all_bbox_preds[-1][:, first:]
        rec_reference_points = last_boxes[..., :3]
        rec_velo = last_boxes[..., -2:]
        rec_memory = outs_dec[-1][:, first:]
        last_scores = all_cls_scores[-1][:, first:]
        rec_score = last_scores.sigmoid().topk(1, dim=-1).values[..., 0:1]
        rec_timestamp = torch.zeros_like(rec_score, dtype=torch.float64)

        # topk proposals
        num_proposals = min(rec_score.shape[1], self.topk_proposals)
        _, topk_indexes = torch.topk(rec_score, num_proposals, dim=1)
        rec_timestamp = topk_gather(rec_timestamp, topk_indexes)
        rec_reference_points = topk_gather(rec_reference_points, topk_indexes).detach()
        rec_memory = topk_gather(rec_memory, topk_indexes).detach()
        rec_ego_pose = topk_gather(rec_ego_pose, topk_indexes)
        rec_velo = topk_gather(rec_velo, topk_indexes).detach()

        def prepend(fresh, bank):
            # newest entries go to the front of the bank
            return torch.cat([fresh, bank], dim=1)

        self.memory_embedding = prepend(rec_memory, self.memory_embedding)
        self.memory_timestamp = prepend(rec_timestamp, self.memory_timestamp)
        self.memory_egopose = prepend(rec_ego_pose, self.memory_egopose)
        self.memory_reference_point = prepend(
            rec_reference_points, self.memory_reference_point
        )
        self.memory_velo = prepend(rec_velo, self.memory_velo)

        ego_pose = data["ego_pose"]
        self.memory_reference_point = transform_reference_points(
            self.memory_reference_point, ego_pose, reverse=False
        )
        self.memory_timestamp -= data["timestamp"].unsqueeze(-1).unsqueeze(-1)
        self.memory_egopose = ego_pose.unsqueeze(1) @ self.memory_egopose

    def position_embeding(self, data, memory_centers, topk_indexes, img_metas):
        """Compute the 3D position embedding and the "cone" feature of tokens.

        Each (selected) image token centre is lifted to ``D`` depth bins,
        mapped through the inverse of ``lidar2img``, normalised with
        ``position_range`` and encoded by ``position_encoder``.

        Args:
            data (dict): Reads ``intrinsics`` and ``lidar2img``.
            memory_centers (Tensor): Unpacked as ``(BN, H, W, _)``. Its first
                two channels are scaled by the padded image width/height
                IN PLACE, so the caller's tensor is modified.
                NOTE(review): presumably normalised 2D token centres - confirm.
            topk_indexes (Tensor | None): Token selection forwarded to
                ``topk_gather``; when ``None`` all ``LEN`` tokens are counted.
            img_metas (list[dict]): Only ``img_metas[0]["pad_shape"][0]`` is
                read.

        Returns:
            tuple: ``(coords_position_embeding, cone)`` where ``cone``
            concatenates the scaled focal lengths with two xyz triplets taken
            from the normalised 3D coordinates.
        """
        eps = 1e-5
        BN, H, W, _ = memory_centers.shape
        B = data["intrinsics"].size(0)

        # focal lengths (fx, fy), made positive and scaled down by 1e3
        intrinsic = torch.stack(
            [data["intrinsics"][..., 0, 0], data["intrinsics"][..., 1, 1]], dim=-1
        )
        intrinsic = torch.abs(intrinsic) / 1e3
        # NOTE(review): `repeat` tiles the existing axis H*W times; if that
        # axis indexes cameras and tokens are camera-major, the per-token
        # intrinsics are interleaved rather than blocked - confirm intent.
        intrinsic = intrinsic.repeat(1, H * W, 1).view(B, -1, 2)
        LEN = intrinsic.size(1)

        num_sample_tokens = topk_indexes.size(1) if topk_indexes is not None else LEN

        # scale the centres to padded-image pixels (mutates the input tensor)
        pad_h, pad_w, _ = img_metas[0]["pad_shape"][0]
        memory_centers[..., 0] = memory_centers[..., 0] * pad_w
        memory_centers[..., 1] = memory_centers[..., 1] * pad_h

        D = self.coords_d.shape[0]

        # build homogeneous (u*d, v*d, d, 1) coordinates for every depth bin
        memory_centers = memory_centers.detach().view(B, LEN, 1, 2)
        topk_centers = topk_gather(memory_centers, topk_indexes).repeat(1, 1, D, 1)
        coords_d = self.coords_d.view(1, 1, D, 1).repeat(B, num_sample_tokens, 1, 1)
        coords = torch.cat([topk_centers, coords_d], dim=-1)
        coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)
        # multiply pixel coordinates by depth, with depth floored at eps
        coords[..., :2] = coords[..., :2] * torch.maximum(
            coords[..., 2:3], torch.ones_like(coords[..., 2:3]) * eps
        )

        coords = coords.unsqueeze(-1)

        dtype = data["lidar2img"].dtype
        device = data["lidar2img"].device
        # the matrix inverse is computed on CPU and moved back afterwards
        # NOTE(review): reason for the CPU round-trip is not visible here.
        img2lidars = data["lidar2img"].cpu().inverse().to(dtype=dtype, device=device)
        img2lidars = (
            img2lidars.view(BN, 1, 1, 4, 4)
            .repeat(1, H * W, D, 1, 1)
            .view(B, LEN, D, 4, 4)
        )
        img2lidars = topk_gather(img2lidars, topk_indexes)

        # project to 3D and normalise xyz with `position_range`
        coords3d = torch.matmul(img2lidars, coords).squeeze(-1)[..., :3]
        coords3d[..., 0:3] = (coords3d[..., 0:3] - self.position_range[0:3]) / (
            self.position_range[3:6] - self.position_range[0:3]
        )
        # flatten the D xyz triplets into one channel axis of size D * 3
        coords3d = coords3d.reshape(B, -1, D * 3)

        pos_embed = inverse_sigmoid(coords3d)
        coords_position_embeding = self.position_encoder(pos_embed)
        intrinsic = topk_gather(intrinsic, topk_indexes)

        # for spatial alignment in focal petr
        # `-3:` is the xyz of the last depth bin; `-90:-87` is the xyz of the
        # bin 30 positions from the end (hard-coded, so it needs D >= 30).
        cone = torch.cat(
            [intrinsic, coords3d[..., -3:], coords3d[..., -90:-87]], dim=-1
        )

        return coords_position_embeding, cone

    def temporal_alignment(self, query_pos, tgt, reference_points):
        """Condition queries and memory on ego motion/time, then merge them.

        The memory bank is embedded, optionally modulated by ego-pose
        encodings, and given a time embedding. The first ``num_propagated``
        memory entries are appended to the current queries; the remainder is
        returned as temporal memory.

        Args:
            query_pos (Tensor): Positional embeddings of the current queries.
                ``+=`` is applied to it, so when ``with_ego_pos`` is False the
                caller's tensor is updated in place.
            tgt (Tensor): Query content features.
            reference_points (Tensor): Reference points of the queries.

        Returns:
            tuple: ``(tgt, query_pos, reference_points, temp_memory, temp_pos,
            rec_ego_pose)``.
        """
        batch = query_pos.size(0)
        device = query_pos.device
        n_prop = self.num_propagated

        def identity_poses(count):
            # (batch, count, 4, 4) stack of identity matrices
            eye = torch.eye(4, device=device)
            return eye.unsqueeze(0).unsqueeze(0).repeat(batch, count, 1, 1)

        # normalise the stored reference points with pc_range
        pc_min = self.pc_range[:3]
        pc_extent = self.pc_range[3:6] - self.pc_range[0:3]
        temp_reference_point = (self.memory_reference_point - pc_min) / pc_extent
        temp_pos = self.query_embedding(pos2posemb3d(temp_reference_point))
        temp_memory = self.memory_embedding
        rec_ego_pose = identity_poses(query_pos.size(1))

        if self.with_ego_pos:
            # current queries: zero leading channels + identity pose
            query_motion = nerf_positional_encoding(
                torch.cat(
                    [
                        torch.zeros_like(reference_points[..., :3]),
                        rec_ego_pose[..., :3, :].flatten(-2),
                    ],
                    dim=-1,
                )
            )
            tgt = self.ego_pose_memory(tgt, query_motion)
            query_pos = self.ego_pose_pe(query_pos, query_motion)

            # memory entries: velocity, timestamp and stored ego pose
            memory_motion = nerf_positional_encoding(
                torch.cat(
                    [
                        self.memory_velo,
                        self.memory_timestamp,
                        self.memory_egopose[..., :3, :].flatten(-2),
                    ],
                    dim=-1,
                ).float()
            )
            temp_pos = self.ego_pose_pe(temp_pos, memory_motion)
            temp_memory = self.ego_pose_memory(temp_memory, memory_motion)

        # time embedding: zero for current queries, stored stamps for memory
        current_time = torch.zeros_like(reference_points[..., :1])
        query_pos += self.time_embedding(pos2posemb1d(current_time))
        temp_pos += self.time_embedding(pos2posemb1d(self.memory_timestamp).float())

        if n_prop > 0:
            tgt = torch.cat([tgt, temp_memory[:, :n_prop]], dim=1)
            query_pos = torch.cat([query_pos, temp_pos[:, :n_prop]], dim=1)
            reference_points = torch.cat(
                [reference_points, temp_reference_point[:, :n_prop]], dim=1
            )
            # `query_pos` already includes the propagated entries here, so the
            # pose tensor ends up with `n_prop` extra identity rows.
            rec_ego_pose = identity_poses(query_pos.shape[1] + n_prop)
            temp_memory = temp_memory[:, n_prop:]
            temp_pos = temp_pos[:, n_prop:]

        return tgt, query_pos, reference_points, temp_memory, temp_pos, rec_ego_pose

    def prepare_for_dn(self, batch_size, reference_points, img_metas):
        if self.training and self.with_dn:
            targets = [
                torch.cat(
                    (
                        img_meta["gt_bboxes_3d"]._data.gravity_center,
                        img_meta["gt_bboxes_3d"]._data.tensor[:, 3:],
                    ),
                    dim=1,
                )
                for img_meta in img_metas
            ]
            labels = [img_meta["gt_labels_3d"]._data for img_meta in img_metas]
            known = [(torch.ones_like(t)).cuda() for t in labels]
            know_idx = known
            unmask_bbox = unmask_label = torch.cat(known)
            # gt_num
            known_num = [t.size(0) for t in targets]

            labels = torch.cat([t for t in labels])
            boxes = torch.cat([t for t in targets])
            batch_idx = torch.cat(
                [torch.full((t.size(0),), i) for i, t in enumerate(targets)]
            )

            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)
            # add noise
            known_indice = known_indice.repeat(self.scalar, 1).view(-1)
            known_labels = (
                labels.repeat(self.scalar, 1)
                .view(-1)
                .long()
                .to(reference_points.device)
            )
            known_bid = batch_idx.repeat(self.scalar, 1).view(-1)
            known_bboxs = boxes.repeat(self.scalar, 1).to(reference_points.device)
            known_bbox_center = known_bboxs[:, :3].clone()
            known_bbox_scale = known_bboxs[:, 3:6].clone()

            if self.bbox_noise_scale > 0:
                diff = known_bbox_scale / 2 + self.bbox_noise_trans
                rand_prob = torch.rand_like(known_bbox_center) * 2 - 1.0
                known_bbox_center += torch.mul(rand_prob, diff) * self.bbox_noise_scale
                known_bbox_center[..., 0:3] = (
                    known_bbox_center[..., 0:3] - self.pc_range[0:3]
                ) / (self.pc_range[3:6] - self.pc_range[0:3])

                known_bbox_center = known_bbox_center.clamp(min=0.0, max=1.0)
                mask = torch.norm(rand_prob, 2, 1) > self.split
                known_labels[mask] = self.num_classes

            single_pad = int(max(known_num))
            pad_size = int(single_pad * self.scalar)
            padding_bbox = torch.zeros(pad_size, 3).to(reference_points.device)
            padded_reference_points = (
                torch.cat([padding_bbox, reference_points], dim=0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )

            if len(known_num):
                map_known_indice = torch.cat(
                    [torch.tensor(range(num)) for num in known_num]
                )  # [1,2, 1,2,3]
                map_known_indice = torch.cat(
                    [map_known_indice + single_pad * i for i in range(self.scalar)]
                ).long()
            if len(known_bid):
                padded_reference_points[(known_bid.long(), map_known_indice)] = (
                    known_bbox_center.to(reference_points.device)
                )

            tgt_size = pad_size + self.num_query
            attn_mask = torch.ones(tgt_size, tgt_size).to(reference_points.device) < 0
            # match query cannot see the reconstruct
            attn_mask[pad_size:, :pad_size] = True
            # reconstruct cannot see each other
            for i in range(self.scalar):
                if i == 0:
                    attn_mask[
                        single_pad * i : single_pad * (i + 1),
                        single_pad * (i + 1) : pad_size,
                    ] = True
                if i == self.scalar - 1:
                    attn_mask[
                        single_pad * i : single_pad * (i + 1), : single_pad * i
                    ] = True
                else:
                    attn_mask[
                        single_pad * i : single_pad * (i + 1),
                        single_pad * (i + 1) : pad_size,
                    ] = True
                    attn_mask[
                        single_pad * i : single_pad * (i + 1), : single_pad * i
                    ] = True

            # update dn mask for temporal modeling
            query_size = pad_size + self.num_query + self.num_propagated
            tgt_size = pad_size + self.num_query + self.memory_len
            temporal_attn_mask = (
                torch.ones(query_size, tgt_size).to(reference_points.device) < 0
            )
            temporal_attn_mask[: attn_mask.size(0), : attn_mask.size(1)] = attn_mask
            temporal_attn_mask[pad_size:, :pad_size] = True
            attn_mask = temporal_attn_mask

            mask_dict = {
                "known_indice": torch.as_tensor(known_indice).long(),
                "batch_idx": torch.as_tensor(batch_idx).long(),
                "map_known_indice": torch.as_tensor(map_known_indice).long(),
                "known_lbs_bboxes": (known_labels, known_bboxs),
                "know_idx": know_idx,
                "pad_size": pad_size,
            }

        else:
            padded_reference_points = reference_points.unsqueeze(0).repeat(
                batch_size, 1, 1
            )
            attn_mask = None
            mask_dict = None

        return padded_reference_points, attn_mask, mask_dict

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Adapt legacy checkpoint keys, then run the default module loading.

        Three conversions are applied to ``state_dict`` in place:

        1. For checkpoints older than version 2, legacy transformer key
           fragments are renamed.
        2. Stacked ``reference_points.weight`` /
           ``pseudo_reference_points.weight`` tensors are split into one entry
           per point, matching the ``ParameterList`` layout used here.
        3. Any ``deprecated_reference_points`` entries are dropped.

        Steps 2 and 3 are applied under the ``prefix`` this module is actually
        loaded with, and also under the literal ``"pts_bbox_head."`` prefix the
        previous implementation was hard-wired to.

        Args:
            state_dict (dict): Checkpoint mapping, modified in place.
            prefix (str): Key prefix of this module inside ``state_dict``.
            local_metadata (dict): Metadata; ``version`` is read from it.
            strict, missing_keys, unexpected_keys, error_msgs: Forwarded
                unchanged to the parent implementation.
        """
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since `AnchorFreeHead._load_from_state_dict` should not be
        # called here. Invoking the default `Module._load_from_state_dict`
        # is enough.

        # Names of some parameters have been changed.
        version = local_metadata.get("version", None)
        if (version is None or version < 2) and self.__class__ is StreamPETRHeadPruning:
            convert_dict = {
                ".self_attn.": ".attentions.0.",
                # '.ffn.': '.ffns.0.',
                ".multihead_attn.": ".attentions.1.",
                ".decoder.norm.": ".decoder.post_norm.",
            }
            for old_key in list(state_dict.keys()):
                # Apply every rename before moving the entry once; moving it
                # per pattern raised KeyError when a key matched two patterns.
                new_key = old_key
                for legacy, current in convert_dict.items():
                    new_key = new_key.replace(legacy, current)
                if new_key != old_key:
                    state_dict[new_key] = state_dict.pop(old_key)

        stacked_params = (
            ("reference_points", self.num_query),
            ("pseudo_reference_points", self.num_propagated),
        )
        # dict.fromkeys de-duplicates while keeping order
        for key_prefix in dict.fromkeys((prefix, "pts_bbox_head.")):
            for name, count in stacked_params:
                stacked_key = f"{key_prefix}{name}.weight"
                if stacked_key not in state_dict:
                    continue
                stacked = state_dict[stacked_key]
                for i in range(count):
                    state_dict[f"{key_prefix}{name}.{i}"] = stacked[i]
                state_dict.pop(stacked_key)

            deprecated_prefix = f"{key_prefix}deprecated_reference_points"
            deprecated_keys = [
                key_ for key_ in state_dict if key_.startswith(deprecated_prefix)
            ]
            for key_ in deprecated_keys:
                state_dict.pop(key_)

        super(AnchorFreeHead, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, memory_center, img_metas, topk_indexes=None, **data):
        """Run the head on one frame and update the temporal memory.

        Args:
            memory_center (Tensor): Token centres handed to
                ``position_embeding``.
            img_metas (list[dict]): Per-sample meta information.
            topk_indexes (Tensor, optional): Token selection forwarded to
                ``topk_gather``.
            **data: Frame inputs. ``img_feats`` is read here as a 5D tensor of
                shape (B, N, C, H, W); the remaining entries are consumed by
                the memory and position-embedding helpers.

        Returns:
            dict: ``all_cls_scores`` and ``all_bbox_preds`` stacked over the
            decoder levels, with the leading denoising queries removed when
            present, plus ``dn_mask_dict`` (``None`` without denoising). The
            first three box channels are sigmoid outputs rescaled to
            ``pc_range``.
        """
        # zero init the memory bank
        self.pre_update_memory(data)

        feats = data["img_feats"]
        B, N, C, H, W = feats.shape
        num_tokens = N * H * W
        # flatten camera and spatial axes into a single token axis
        memory = feats.permute(0, 1, 3, 4, 2).reshape(B, num_tokens, C)
        memory = topk_gather(memory, topk_indexes)

        pos_embed, cone = self.position_embeding(
            data, memory_center, topk_indexes, img_metas
        )

        memory = self.memory_embed(memory)

        # spatial_alignment in focal petr
        memory = self.spatial_alignment(memory, cone)
        pos_embed = self.featurized_pe(pos_embed, memory)

        reference_points = torch.stack(list(self.reference_points))
        reference_points, attn_mask, mask_dict = self.prepare_for_dn(
            B, reference_points, img_metas
        )
        query_pos = self.query_embedding(pos2posemb3d(reference_points))
        tgt = torch.zeros_like(query_pos)

        # prepare for the tgt and query_pos using mln.
        tgt, query_pos, reference_points, temp_memory, temp_pos, rec_ego_pose = (
            self.temporal_alignment(query_pos, tgt, reference_points)
        )

        # transformer here is a little different from PETR
        outs_dec, _ = self.transformer(
            memory, tgt, query_pos, pos_embed, attn_mask, temp_memory, temp_pos
        )
        outs_dec = torch.nan_to_num(outs_dec)

        # The logit-space reference is the same for every decoder level.
        reference = inverse_sigmoid(reference_points.clone())
        assert reference.shape[-1] == 3

        outputs_classes = []
        outputs_coords = []
        for lvl, level_out in enumerate(outs_dec):
            level_cls = self.cls_branches[lvl](level_out)
            level_box = self.reg_branches[lvl](level_out)

            # centre = sigmoid(predicted offset + reference logit)
            level_box[..., 0:3] += reference[..., 0:3]
            level_box[..., 0:3] = level_box[..., 0:3].sigmoid()

            outputs_classes.append(level_cls)
            outputs_coords.append(level_box)

        all_cls_scores = torch.stack(outputs_classes)
        all_bbox_preds = torch.stack(outputs_coords)
        # map the normalised centres back into the point-cloud range
        pc_min = self.pc_range[0:3]
        all_bbox_preds[..., 0:3] = (
            all_bbox_preds[..., 0:3] * (self.pc_range[3:6] - pc_min) + pc_min
        )

        # update the memory bank
        self.post_update_memory(
            data, rec_ego_pose, all_cls_scores, all_bbox_preds, outs_dec, mask_dict
        )

        pad_size = mask_dict["pad_size"] if mask_dict else 0
        if pad_size > 0:
            # split denoising predictions from the matching predictions
            mask_dict["output_known_lbs_bboxes"] = (
                all_cls_scores[:, :, :pad_size, :],
                all_bbox_preds[:, :, :pad_size, :],
            )
            return {
                "all_cls_scores": all_cls_scores[:, :, pad_size:, :],
                "all_bbox_preds": all_bbox_preds[:, :, pad_size:, :],
                "dn_mask_dict": mask_dict,
            }

        return {
            "all_cls_scores": all_cls_scores,
            "all_bbox_preds": all_bbox_preds,
            "dn_mask_dict": None,
        }

    def prepare_for_loss(self, mask_dict):
        """Collect the denoising (dn) predictions and targets for the dn loss.

        Args:
            mask_dict (dict): dn bookkeeping. The keys read here are
                ``"output_known_lbs_bboxes"``, ``"known_lbs_bboxes"``,
                ``"map_known_indice"``, ``"known_indice"`` and
                ``"batch_idx"``.

        Returns:
            tuple: ``(known_labels, known_bboxs, output_known_class,
            output_known_coord, num_tgt)``. ``num_tgt`` is the number of
            elements of ``mask_dict["known_indice"]``.
        """
        pred_cls, pred_coord = mask_dict["output_known_lbs_bboxes"]
        known_labels, known_bboxs = mask_dict["known_lbs_bboxes"]
        slot_in_pad = mask_dict["map_known_indice"].long()
        known_indice = mask_dict["known_indice"].long().cpu()
        sample_of_known = mask_dict["batch_idx"].long()[known_indice]

        def _pick_known(pred):
            # Move dim 0 behind dims 1 and 2, pick one (dim-1, dim-2) entry per
            # known target, then bring the former dim 0 back to the front.
            picked = pred.permute(1, 2, 0, 3)[(sample_of_known, slot_in_pad)]
            return picked.permute(1, 0, 2)

        # Predictions are left untouched when their leading dim is empty.
        if len(pred_cls) > 0:
            pred_cls = _pick_known(pred_cls)
            pred_coord = _pick_known(pred_coord)

        return (
            known_labels,
            known_bboxs,
            pred_cls,
            pred_coord,
            known_indice.numel(),
        )

    def _get_target_single(self, cls_score, bbox_pred, gt_labels, gt_bboxes):
        """Build classification and regression targets for one sample.

        Predictions are matched to ground truth with ``self.assigner`` and the
        match is turned into positive / negative indices by ``self.sampler``.

        Args:
            cls_score: Classification output for the sample; only forwarded
                to the assigner.
            bbox_pred: Box output for the sample; its first dim gives the
                number of predictions.
            gt_labels: Ground-truth labels, indexed by the assigned gt index.
            gt_bboxes: Ground-truth boxes; ``gt_bboxes.size(1)`` sets the
                width of the regression targets.

        Returns:
            tuple: ``(labels, label_weights, bbox_targets, bbox_weights,
            pos_inds, neg_inds)``.
        """
        query_count = bbox_pred.size(0)

        match = self.assigner.assign(
            bbox_pred,
            cls_score,
            gt_bboxes,
            gt_labels,
            self.match_costs,
            self.match_with_velo,
        )
        sampled = self.sampler.sample(match, bbox_pred, gt_bboxes)
        matched_idx = sampled.pos_inds
        unmatched_idx = sampled.neg_inds

        # Every prediction starts with label ``self.num_classes`` and weight 1.
        labels = gt_bboxes.new_full(
            (query_count,), self.num_classes, dtype=torch.long
        )
        label_weights = gt_bboxes.new_ones(query_count)

        # Regression targets are as wide as the gt boxes, while the weights
        # keep the full width of the prediction.
        target_width = gt_bboxes.size(1)
        bbox_targets = torch.zeros_like(bbox_pred)[..., :target_width]
        bbox_weights = torch.zeros_like(bbox_pred)

        # Only matched predictions receive a box target and a real label.
        if sampled.num_gts > 0:
            bbox_targets[matched_idx] = sampled.pos_gt_bboxes
            bbox_weights[matched_idx] = 1.0
            labels[matched_idx] = gt_labels[sampled.pos_assigned_gt_inds]

        return (
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            matched_idx,
            unmatched_idx,
        )

    def get_targets(
        self,
        cls_scores_list,
        bbox_preds_list,
        gt_bboxes_list,
        gt_labels_list,
    ):
        """Compute targets for every sample via ``_get_target_single``.

        Args:
            cls_scores_list (list): Per-sample classification outputs.
            bbox_preds_list (list): Per-sample box outputs.
            gt_bboxes_list (list): Per-sample ground-truth boxes.
            gt_labels_list (list): Per-sample ground-truth labels.

        Returns:
            tuple: ``(labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_pos, num_total_neg)``; the two counts
            are summed over all samples.
        """
        # ``_get_target_single`` takes labels before boxes, hence the order.
        per_sample = multi_apply(
            self._get_target_single,
            cls_scores_list,
            bbox_preds_list,
            gt_labels_list,
            gt_bboxes_list,
        )
        (
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
        ) = per_sample[:4]
        pos_inds_list, neg_inds_list = per_sample[4:]

        positives = sum(idx.numel() for idx in pos_inds_list)
        negatives = sum(idx.numel() for idx in neg_inds_list)
        return (
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            positives,
            negatives,
        )

    def loss_single(self, cls_scores, bbox_preds, gt_bboxes_list, gt_labels_list):
        """Classification and box loss for the outputs of one decoder layer.

        Args:
            cls_scores: Classification outputs; the first dim is iterated
                per sample.
            bbox_preds: Box outputs, split per sample along the first dim.
            gt_bboxes_list (list): Per-sample ground-truth boxes.
            gt_labels_list (list): Per-sample ground-truth labels.

        Returns:
            tuple: ``(loss_cls, loss_bbox)``, both passed through
            ``torch.nan_to_num``.
        """
        batch_size = cls_scores.size(0)
        per_sample_scores = [cls_scores[b] for b in range(batch_size)]
        per_sample_boxes = [bbox_preds[b] for b in range(batch_size)]
        (
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_pos,
            num_total_neg,
        ) = self.get_targets(
            per_sample_scores,
            per_sample_boxes,
            gt_bboxes_list,
            gt_labels_list,
        )
        labels, label_weights, bbox_targets, bbox_weights = (
            torch.cat(chunks, 0)
            for chunks in (
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
            )
        )

        # --- classification ---
        flat_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # Weighted avg_factor, constructed to match the official DETR repo.
        cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(flat_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)
        loss_cls = self.loss_cls(
            flat_scores, labels, label_weights, avg_factor=cls_avg_factor
        )

        # Average the number of positives across all GPUs for normalization,
        # never letting it drop below 1.
        pos_count = loss_cls.new_tensor([num_total_pos])
        pos_count = torch.clamp(reduce_mean(pos_count), min=1).item()

        # --- regression (L1) ---
        flat_boxes = bbox_preds.reshape(-1, bbox_preds.size(-1))
        target_codes = normalize_bbox(bbox_targets, self.pc_range)
        # Rows whose normalized target has a non-finite entry are dropped.
        finite_rows = torch.isfinite(target_codes).all(dim=-1)
        weighted = bbox_weights * self.code_weights

        # Only the first 10 code dims take part in the box loss.
        loss_bbox = self.loss_bbox(
            flat_boxes[finite_rows, :10],
            target_codes[finite_rows, :10],
            weighted[finite_rows, :10],
            avg_factor=pos_count,
        )

        return torch.nan_to_num(loss_cls), torch.nan_to_num(loss_bbox)

    def dn_loss_single(
        self, cls_scores, bbox_preds, known_bboxs, known_labels, num_total_pos=None
    ):
        """Denoising (dn) classification and box loss for one decoder layer.

        Unlike ``loss_single`` no assignment is run here: the given
        ``known_labels`` / ``known_bboxs`` are used directly as targets.

        Args:
            cls_scores: Classification outputs of the dn queries.
            bbox_preds: Box outputs of the dn queries.
            known_bboxs: Target boxes for the dn queries.
            known_labels: Target labels for the dn queries.
            num_total_pos: Number of dn targets. Used in arithmetic, so the
                ``None`` default is not usable in practice.

        Returns:
            tuple: ``(loss_cls, loss_bbox)``, each passed through
            ``torch.nan_to_num`` and scaled by ``self.dn_weight``.
        """
        # --- classification ---
        flat_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # The original comment calls this factor the "positive rate".
        # NOTE(review): pi / 6 * split**3 is the volume of a sphere whose
        # diameter is ``self.split`` -- confirm that this is the intent.
        cls_avg_factor = (
            num_total_pos * 3.14159 / 6 * self.split * self.split * self.split
        )
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(flat_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        # Every dn query gets unit label weight.
        unit_label_weights = torch.ones_like(known_labels)
        loss_cls = self.loss_cls(
            flat_scores,
            known_labels.long(),
            unit_label_weights,
            avg_factor=cls_avg_factor,
        )

        # Average the number of targets across all GPUs for normalization,
        # never letting it drop below 1.
        pos_count = loss_cls.new_tensor([num_total_pos])
        pos_count = torch.clamp(reduce_mean(pos_count), min=1).item()

        # --- regression (L1) ---
        # Unit weights in the shape of the incoming predictions, scaled by the
        # per-dimension code weights.
        weighted = torch.ones_like(bbox_preds) * self.code_weights
        flat_boxes = bbox_preds.reshape(-1, bbox_preds.size(-1))
        target_codes = normalize_bbox(known_bboxs, self.pc_range)
        # Rows whose normalized target has a non-finite entry are dropped.
        finite_rows = torch.isfinite(target_codes).all(dim=-1)

        # Only the first 10 code dims take part in the box loss.
        loss_bbox = self.loss_bbox(
            flat_boxes[finite_rows, :10],
            target_codes[finite_rows, :10],
            weighted[finite_rows, :10],
            avg_factor=pos_count,
        )

        return (
            self.dn_weight * torch.nan_to_num(loss_cls),
            self.dn_weight * torch.nan_to_num(loss_bbox),
        )

    def fake_loss(self, deprecated_reference_points, dtype, device):
        """Return a zero-valued scalar that still depends on the given tensors.

        Each tensor contributes ``0.0 * tensor.mean()``, so the value stays 0
        while the result remains connected to every tensor in the autograd
        graph.

        Args:
            deprecated_reference_points: Sequence of tensors to touch.
            dtype (torch.dtype): dtype of the returned scalar.
            device (torch.device): Device of the returned scalar.

        Returns:
            torch.Tensor: 0-dim tensor with value 0.
        """
        total = torch.tensor(0, dtype=dtype, device=device)
        for points in deprecated_reference_points:
            # Accumulate in place so the result keeps the requested dtype.
            total += 0.0 * points.mean()
        return total

    @force_fp32(apply_to=("preds_dicts",))
    def loss(self, gt_bboxes_list, gt_labels_list, preds_dicts):
        """Compute the training losses of the head.

        Args:
            gt_bboxes_list (list): Per-sample ground-truth boxes; each item
                must expose ``gravity_center`` and ``tensor``.
            gt_labels_list (list): Per-sample ground-truth labels. The device
                of the first item is where the box targets are placed.
            preds_dicts (dict): Head outputs with the keys
                ``"all_cls_scores"``, ``"all_bbox_preds"`` and
                ``"dn_mask_dict"``.

        Returns:
            dict: ``loss_cls`` / ``loss_bbox`` for the last decoder layer,
            ``d{i}.loss_cls`` / ``d{i}.loss_bbox`` for the earlier ones,
            ``fake_loss``, and -- when dn information is present or
            ``self.with_dn`` is set -- the matching ``dn_loss_*`` entries.
        """
        all_cls_scores = preds_dicts["all_cls_scores"]
        all_bbox_preds = preds_dicts["all_bbox_preds"]

        dtype = all_cls_scores.dtype
        num_dec_layers = len(all_cls_scores)

        # Targets are rebuilt as (gravity center, remaining box columns) on
        # the device of the labels.
        device = gt_labels_list[0].device
        gt_bboxes_list = [
            torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), dim=1).to(
                device
            )
            for gt_bboxes in gt_bboxes_list
        ]

        # Every decoder layer is supervised with the same ground truth.
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_gt_bboxes_list,
            all_gt_labels_list,
        )

        fake_loss = self.fake_loss(self.deprecated_reference_points, dtype, device)

        loss_dict = dict()

        # loss from the last decoder layer
        loss_dict["loss_cls"] = losses_cls[-1]
        loss_dict["loss_bbox"] = losses_bbox[-1]
        loss_dict["fake_loss"] = fake_loss

        # loss from other decoder layers
        for layer, (loss_cls_i, loss_bbox_i) in enumerate(
            zip(losses_cls[:-1], losses_bbox[:-1])
        ):
            loss_dict[f"d{layer}.loss_cls"] = loss_cls_i
            loss_dict[f"d{layer}.loss_bbox"] = loss_bbox_i

        dn_mask_dict = preds_dicts["dn_mask_dict"]
        if dn_mask_dict is not None:
            (
                known_labels,
                known_bboxs,
                output_known_class,
                output_known_coord,
                num_tgt,
            ) = self.prepare_for_loss(dn_mask_dict)
            all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
            all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
            all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]

            dn_losses_cls, dn_losses_bbox = multi_apply(
                self.dn_loss_single,
                output_known_class,
                output_known_coord,
                all_known_bboxs_list,
                all_known_labels_list,
                all_num_tgts_list,
            )
            loss_dict["dn_loss_cls"] = dn_losses_cls[-1]
            loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1]
            for layer, (loss_cls_i, loss_bbox_i) in enumerate(
                zip(dn_losses_cls[:-1], dn_losses_bbox[:-1])
            ):
                loss_dict[f"d{layer}.dn_loss_cls"] = loss_cls_i
                loss_dict[f"d{layer}.dn_loss_bbox"] = loss_bbox_i

        elif self.with_dn:
            # No dn outputs for this batch: the dn entries are filled with
            # detached copies of the regular loss so the keys still exist but
            # contribute no gradient.
            # NOTE(review): the loss is recomputed rather than reusing
            # ``losses_cls`` / ``losses_bbox``; presumably this keeps the
            # number of ``reduce_mean`` calls equal to the dn branch on other
            # ranks -- confirm before replacing it with the cached values.
            dn_losses_cls, dn_losses_bbox = multi_apply(
                self.loss_single,
                all_cls_scores,
                all_bbox_preds,
                all_gt_bboxes_list,
                all_gt_labels_list,
            )
            loss_dict["dn_loss_cls"] = dn_losses_cls[-1].detach()
            loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1].detach()
            for layer, (loss_cls_i, loss_bbox_i) in enumerate(
                zip(dn_losses_cls[:-1], dn_losses_bbox[:-1])
            ):
                loss_dict[f"d{layer}.dn_loss_cls"] = loss_cls_i.detach()
                loss_dict[f"d{layer}.dn_loss_bbox"] = loss_bbox_i.detach()

        return loss_dict

    @force_fp32(apply_to=("preds_dicts",))
    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
        """Decode the head outputs into per-sample boxes, scores and labels.

        Args:
            preds_dicts: Head outputs, handed to ``self.bbox_coder.decode``.
            img_metas (list[dict]): Per-sample meta information; the
                ``"box_type_3d"`` entry is called to wrap the decoded boxes.
            rescale (bool): Not used by this method.

        Returns:
            list[list]: One ``[bboxes, scores, labels]`` entry per sample.
        """
        decoded = self.bbox_coder.decode(preds_dicts)

        ret_list = []
        for i, preds in enumerate(decoded):
            bboxes = preds["bboxes"]
            # Column 2 is lowered by half of column 5, in place.
            # NOTE(review): presumably converts a gravity-center z to a
            # bottom-center z using the box height -- confirm column layout.
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
            bboxes = img_metas[i]["box_type_3d"](bboxes, bboxes.size(-1))
            scores = preds["scores"]
            labels = preds["labels"]
            ret_list.append([bboxes, scores, labels])
        return ret_list
