from typing import List, Optional, Tuple, Union
import warnings

import numpy as np
import torch
import torch.nn as nn

from mmcv.cnn.bricks.registry import (
    ATTENTION,
    PLUGIN_LAYERS,
    POSITIONAL_ENCODING,
    FEEDFORWARD_NETWORK,
    NORM_LAYERS,
)
from mmcv.runner import BaseModule, force_fp32
from mmcv.utils import build_from_cfg
from mmdet.core.bbox.builder import BBOX_SAMPLERS
from mmdet.core.bbox.builder import BBOX_CODERS
from mmdet.models import HEADS, LOSSES
from mmdet.core import reduce_mean

from ..blocks import DeformableFeatureAggregation as DFG
from ..MapExpert import *

__all__ = ["Sparse4DHead"]


@HEADS.register_module()
class Sparse4DHead(BaseModule):
    def __init__(
        self,
        instance_bank: dict,
        anchor_encoder: dict,
        anchor_encoder_map: dict,
        graph_model: dict,
        norm_layer: dict,
        ffn: dict,
        deformable_model: dict,
        refine_layer: dict,
        num_decoder: int = 6,
        num_single_frame_decoder: int = -1,
        temp_graph_model: dict = None,
        loss_cls: dict = None,
        loss_reg: dict = None,
        decoder: dict = None,
        sampler: dict = None,
        gt_cls_key: str = "gt_labels_3d",
        gt_reg_key: str = "gt_bboxes_3d",
        gt_id_key: str = "instance_id",
        with_instance_id: bool = True,
        task_prefix: str = 'det',
        reg_weights: List = None,
        # map
        loss_cls_map: dict = None,
        loss_reg_map: dict = None,
        decoder_map: dict = None,
        sampler_map: dict = None,
        gt_cls_key_map: str = "gt_map_labels",
        gt_reg_key_map: str = "gt_map_pts",
        gt_id_key_map: str = "map_instance_id",
        with_instance_id_map: bool = True,
        task_prefix_map: str = 'map',
        reg_weights_map: List = None,

        operation_order: Optional[List[str]] = None,
        cls_threshold_to_reg: float = -1,
        dn_loss_weight: float = 5.0,
        decouple_attn: bool = True,
        init_cfg: dict = None,
        **kwargs,
    ):
        """Build the joint detection + map decoder head.

        Every ``dict`` argument is a config passed to ``build_from_cfg``
        with the registry noted below; a ``None`` config yields ``None``.

        Args:
            instance_bank: Config built from ``PLUGIN_LAYERS``. The built
                object must expose ``embed_dims``.
            anchor_encoder: Config built from ``POSITIONAL_ENCODING``
                (detection anchors).
            anchor_encoder_map: Config built from ``POSITIONAL_ENCODING``
                (map anchors).
            graph_model: ``ATTENTION`` config used for every ``"gnn"`` op.
            norm_layer: ``NORM_LAYERS`` config used for every ``"norm"`` op.
            ffn: ``FEEDFORWARD_NETWORK`` config used for every ``"ffn"`` op.
            deformable_model: ``ATTENTION`` config used for every
                ``"deformable"`` op.
            refine_layer: ``PLUGIN_LAYERS`` config used for every
                ``"refine"`` op.
            num_decoder: Number of repetitions of the default op sequence
                when ``operation_order`` is ``None``.
            num_single_frame_decoder: Stored as-is on the instance.
            temp_graph_model: ``ATTENTION`` config used for every
                ``"temp_gnn"`` op.
            loss_cls, loss_reg: ``LOSSES`` configs (detection).
            decoder: ``BBOX_CODERS`` config (detection).
            sampler: ``BBOX_SAMPLERS`` config (detection).
            gt_cls_key, gt_reg_key, gt_id_key: Keys stored for looking up
                detection ground truth.
            with_instance_id: Stored flag for detection instance ids.
            task_prefix: Prefix stored for detection loss names.
            reg_weights: Per-dimension regression weights; defaults to
                ``[1.0] * 10``.
            loss_cls_map, loss_reg_map, decoder_map, sampler_map: Map
                counterparts, built from the same registries as above.
            gt_cls_key_map, gt_reg_key_map, gt_id_key_map: Keys stored for
                looking up map ground truth.
            with_instance_id_map: Stored flag for map instance ids.
            task_prefix_map: Prefix stored for map loss names.
            reg_weights_map: Per-dimension map regression weights;
                defaults to ``[1.0] * 40``.
            operation_order: Explicit list of op names, one per layer. Ops
                not in the known set are built as ``None``.
            cls_threshold_to_reg: Stored as-is on the instance.
            dn_loss_weight: Stored as-is on the instance.
            decouple_attn: If True, ``fc_before`` / ``fc_after`` are bias-free
                linear layers (``embed_dims -> 2 * embed_dims`` and back);
                otherwise both are ``nn.Identity``.
            init_cfg: Forwarded to ``BaseModule.__init__``.
            **kwargs: Accepted and ignored.
        """
        super(Sparse4DHead, self).__init__(init_cfg)
        self.num_decoder = num_decoder
        self.num_single_frame_decoder = num_single_frame_decoder
        self.cls_threshold_to_reg = cls_threshold_to_reg
        self.dn_loss_weight = dn_loss_weight
        self.decouple_attn = decouple_attn

        # detection task settings
        self.gt_cls_key = gt_cls_key
        self.gt_reg_key = gt_reg_key
        self.gt_id_key = gt_id_key
        self.with_instance_id = with_instance_id
        self.task_prefix = task_prefix
        self.reg_weights = [1.0] * 10 if reg_weights is None else reg_weights

        # map task settings
        self.gt_cls_key_map = gt_cls_key_map
        self.gt_reg_key_map = gt_reg_key_map
        self.gt_id_key_map = gt_id_key_map
        self.with_instance_id_map = with_instance_id_map
        self.task_prefix_map = task_prefix_map
        self.reg_weights_map = (
            [1.0] * 40 if reg_weights_map is None else reg_weights_map
        )

        if operation_order is None:
            operation_order = [
                "temp_gnn",
                "gnn",
                "norm",
                "deformable",
                "norm",
                "ffn",
                "norm",
                "refine",
            ] * num_decoder
            # Drop the leading "temp_gnn", "gnn" and "norm" ops so the first
            # block starts directly at "deformable".
            operation_order = operation_order[3:]
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        # NOTE: assignment order below determines sub-module registration
        # order (and therefore parameter / state_dict order); keep it stable.
        self.instance_bank = build(instance_bank, PLUGIN_LAYERS)
        self.anchor_encoder = build(anchor_encoder, POSITIONAL_ENCODING)
        self.sampler = build(sampler, BBOX_SAMPLERS)
        self.decoder = build(decoder, BBOX_CODERS)
        self.loss_cls = build(loss_cls, LOSSES)
        self.loss_reg = build(loss_reg, LOSSES)
        # map
        self.anchor_encoder_map = build(anchor_encoder_map, POSITIONAL_ENCODING)
        self.sampler_map = build(sampler_map, BBOX_SAMPLERS)
        self.decoder_map = build(decoder_map, BBOX_CODERS)
        self.loss_cls_map = build(loss_cls_map, LOSSES)
        self.loss_reg_map = build(loss_reg_map, LOSSES)

        self.op_config_map = {
            "temp_gnn": [temp_graph_model, ATTENTION],
            "gnn": [graph_model, ATTENTION],
            "norm": [norm_layer, NORM_LAYERS],
            "ffn": [ffn, FEEDFORWARD_NETWORK],
            "deformable": [deformable_model, ATTENTION],
            "refine": [refine_layer, PLUGIN_LAYERS],
        }
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )

        # =========== mixture-of-experts blocks ===========
        def expert_init(m):
            # Linear layers: N(0, 0.02) weights and zero bias.
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        def build_expert():
            # SparseMoeBlock is presumably provided by the `..MapExpert`
            # star import -- TODO confirm.
            expert = SparseMoeBlock()
            expert.apply(expert_init)
            return expert

        # Six experts, one per "ffn" op. They are kept as individually named
        # attributes (rather than a ModuleList) so parameter names stay
        # `expert1.*` ... `expert6.*`.
        self.expert1 = build_expert()
        self.expert2 = build_expert()
        self.expert3 = build_expert()
        self.expert4 = build_expert()
        self.expert5 = build_expert()
        self.expert6 = build_expert()

        self.embed_dims = self.instance_bank.embed_dims
        if self.decouple_attn:
            self.fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.fc_before = nn.Identity()
            self.fc_after = nn.Identity()

    def init_weights(self):
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op != "refine":
                for p in self.layers[i].parameters():
                    if p.dim() > 1:
                        nn.init.xavier_uniform_(p)
        for m in self.modules():
            if hasattr(m, "init_weight"):
                m.init_weight()

    def graph_model(
        self,
        index,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        **kwargs,
    ):
        if self.decouple_attn:
            query = torch.cat([query, query_pos], dim=-1)
            if key is not None:
                key = torch.cat([key, key_pos], dim=-1)
            query_pos, key_pos = None, None
        if value is not None:
            value = self.fc_before(value)
        return self.fc_after(
            self.layers[index](
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                **kwargs,
            )
        )

    def forward(
        self,
        feature_maps: Union[torch.Tensor, List],
        metas: dict,
    ):
        """Run the joint detection/map decoder.

        Detection and map instances are fetched from the instance bank,
        concatenated along the instance axis (dim ``-2``) and pushed through
        the ops in ``self.operation_order``. Each ``"ffn"`` op is paired with
        one of the six MoE experts in execution order (first executed ffn ->
        ``expert1``, second -> ``expert2``, ...); the expert output is added
        to the ffn output.

        Args:
            feature_maps: A tensor or a list of tensors; a bare tensor is
                wrapped into a one-element list. The batch size is read from
                dim 0 of the first entry.
            metas: Batch meta dict. It is forwarded to the instance bank and
                to the deformable layers. In training mode
                ``metas["img_metas"]`` and the entries named by
                ``self.gt_cls_key`` / ``self.gt_reg_key`` are read as well.

        Returns:
            tuple: ``(output, output_map)``.

            ``output`` holds ``"classification"``, ``"prediction"`` and
            ``"quality"`` (lists with one entry per ``"refine"`` op), the
            detection slices ``"instance_feature"`` and ``"anchor_embed"``,
            ``"all_router_logits"`` (tuple with one entry per executed ffn)
            and, if ``with_instance_id`` is set, ``"instance_id"``.

            ``output_map`` holds ``"classification_map"``,
            ``"prediction_map"``, ``"quality_map"``,
            ``"instance_feature_map"``, ``"anchor_embed_map"`` and, if
            ``with_instance_id_map`` is set, ``"instance_id_map"``.

        Raises:
            ValueError: If more ``"ffn"`` ops are executed than there are
                experts.
            NotImplementedError: If a built layer has an unknown op name.
        """
        if isinstance(feature_maps, torch.Tensor):
            feature_maps = [feature_maps]
        batch_size = feature_maps[0].shape[0]

        # ========= get instance info ============
        # Cached denoising metas are only valid for the same batch size.
        if (
            self.sampler.dn_metas is not None
            and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
        ):
            self.sampler.dn_metas = None
        (
            instance_feature,
            anchor,
            temp_instance_feature,  # instances from the previous frame
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )
        (
            instance_feature_map,
            anchor_map,
            temp_instance_feature_map,  # instances from the previous frame
            temp_anchor_map,
            _time_interval_map,
        ) = self.instance_bank.get_map(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

        # ========= denoising (currently disabled) ============
        # The denoising anchors are still requested from the sampler but the
        # result is discarded, so no noisy instances are appended and no
        # attention mask is used. The call is kept so that any side effects
        # of the sampler (e.g. random-number consumption) stay unchanged.
        attn_mask = None
        if self.training and hasattr(self.sampler, "get_dn_anchors"):
            if self.gt_id_key in metas["img_metas"][0]:
                gt_instance_id = [
                    torch.from_numpy(x[self.gt_id_key]).cuda()
                    for x in metas["img_metas"]
                ]
            else:
                gt_instance_id = None
            self.sampler.get_dn_anchors(
                metas[self.gt_cls_key],
                metas[self.gt_reg_key],
                gt_instance_id,
            )

        # ========= anchor embeddings ============
        anchor_embed = self.anchor_encoder(anchor)
        if temp_anchor is not None:
            temp_anchor_embed = self.anchor_encoder(temp_anchor)
        else:
            temp_anchor_embed = None

        anchor_embed_map = self.anchor_encoder_map(anchor_map)
        if temp_anchor_map is not None:
            temp_anchor_embed_map = self.anchor_encoder_map(temp_anchor_map)
        else:
            temp_anchor_embed_map = None

        # =================== forward the layers ====================
        prediction = []
        classification = []
        quality = []
        prediction_map = []
        classification_map = []
        quality_map = []

        # Detection instances come first, map instances second.
        instance_feature = torch.cat(
            [instance_feature, instance_feature_map], dim=-2
        )
        anchor_embed = torch.cat([anchor_embed, anchor_embed_map], dim=-2)
        # NOTE(review): this assumes the detection and map temporal caches
        # are either both present or both absent -- confirm.
        if temp_instance_feature is not None or temp_anchor_embed is not None:
            temp_instance_feature = torch.cat(
                [temp_instance_feature, temp_instance_feature_map], dim=-2
            )
            temp_anchor_embed_all = torch.cat(
                [temp_anchor_embed, temp_anchor_embed_map], dim=-2
            )
        else:
            temp_instance_feature = None
            temp_anchor_embed_all = None

        # One expert per executed "ffn" op, consumed in order.
        experts = (
            self.expert1,
            self.expert2,
            self.expert3,
            self.expert4,
            self.expert5,
            self.expert6,
        )
        num_ffn_seen = 0
        all_router_logits = ()

        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op == "temp_gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed_all,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )
            elif op == "gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    value=instance_feature,
                    query_pos=anchor_embed,
                    attn_mask=attn_mask,
                )
            elif op == "norm":
                instance_feature = self.layers[i](instance_feature)
            elif op == "ffn":
                if num_ffn_seen >= len(experts):
                    raise ValueError(
                        f"operation_order contains more 'ffn' ops than the "
                        f"{len(experts)} available experts "
                        f"(layer index {i})."
                    )
                expert = experts[num_ffn_seen]
                num_ffn_seen += 1

                instance_feature_ffn = self.layers[i](instance_feature)
                # The expert sees the same input as the ffn; both outputs
                # are summed.
                hidden_states, router_logits = expert(instance_feature)
                instance_feature = hidden_states + instance_feature_ffn
                all_router_logits += (router_logits,)
            elif op == "deformable":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_map,
                    anchor_embed,
                    feature_maps,
                    metas,
                )
            elif op == "refine":
                anchor, cls, qt, anchor_map, cls_map, qt_map = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_map,
                    anchor_embed,
                    time_interval=time_interval,
                    return_cls=True,
                    anchor_size=anchor.shape[1],
                )
                prediction.append(anchor)
                classification.append(cls)
                quality.append(qt)
                # map
                prediction_map.append(anchor_map)
                classification_map.append(cls_map)
                quality_map.append(qt_map)

                if len(prediction) == self.num_single_frame_decoder:
                    # Hand the current instances to the instance bank and
                    # continue with what it returns, separately for
                    # detection and map.
                    instance_feature_det, anchor = self.instance_bank.update(
                        instance_feature[:, : anchor.shape[1], :], anchor, cls
                    )
                    (
                        instance_feature_map,
                        anchor_map,
                    ) = self.instance_bank.update_map(
                        instance_feature[:, anchor.shape[1]:, :],
                        anchor_map,
                        cls_map,
                    )
                    instance_feature = torch.cat(
                        [instance_feature_det, instance_feature_map], dim=-2
                    )

                # Re-encode the refined anchors for the following layers.
                anchor_embed_det = self.anchor_encoder(anchor)
                anchor_embed_map = self.anchor_encoder_map(anchor_map)
                anchor_embed = torch.cat(
                    [anchor_embed_det, anchor_embed_map], dim=-2
                )

                if (
                    len(prediction) > self.num_single_frame_decoder
                    and temp_anchor_embed is not None
                ):
                    temp_anchor_embed = anchor_embed_det[
                        :, : self.instance_bank.num_temp_instances
                    ]
                    temp_anchor_embed_map = anchor_embed_map[
                        :, : self.instance_bank.num_temp_instances_map, :
                    ]
                    temp_anchor_embed_all = torch.cat(
                        [temp_anchor_embed, temp_anchor_embed_map], dim=-2
                    )
            else:
                raise NotImplementedError(f"{op} is not supported.")

        # Number of detection instances; everything after belongs to map.
        num_det = anchor.shape[1]
        output = {
            "classification": classification,
            "prediction": prediction,
            "quality": quality,
            "instance_feature": instance_feature[:, :num_det, :],
            "anchor_embed": anchor_embed[:, :num_det, :],
        }
        output_map = {
            "classification_map": classification_map,
            "prediction_map": prediction_map,
            "quality_map": quality_map,
            "instance_feature_map": instance_feature[:, num_det:, :],
            "anchor_embed_map": anchor_embed[:, num_det:, :],
        }

        # cache current instances for temporal modeling
        self.instance_bank.cache(
            instance_feature[:, :num_det, :],
            anchor[:, :num_det, :],
            cls,
            metas,
            feature_maps,
        )
        self.instance_bank.cache_map(
            instance_feature[:, num_det:, :],
            anchor_map,
            cls_map,
            metas,
            feature_maps,
        )

        if self.with_instance_id:
            instance_id = self.instance_bank.get_instance_id(
                cls, anchor, self.decoder.score_threshold
            )
            output["instance_id"] = instance_id
        if self.with_instance_id_map:
            # NOTE(review): uses the detection decoder's score threshold for
            # map instances as well -- confirm this is intended.
            instance_id_map = self.instance_bank.get_instance_id_map(
                cls_map, anchor_map, self.decoder.score_threshold
            )
            output_map["instance_id_map"] = instance_id_map

        # router logits of all executed experts, consumed by the loss
        output["all_router_logits"] = all_router_logits
        return output, output_map

    @force_fp32(apply_to=("gate_logits",))
    def balance_loss(self, gate_logits):
        """Compute the weighted router load-balancing loss.

        Args:
            gate_logits: Router logits collected from the experts; passed
                straight to ``load_balancing_loss_func``.

        Returns:
            dict: ``{"all_router_loss": 0.1 * load_balancing_loss_func(...)}``.
        """
        # NOTE(review): the positional arguments 4 and 2 are presumably the
        # number of experts and the top-k routing size -- confirm against
        # the definition of load_balancing_loss_func.
        router_loss = load_balancing_loss_func(gate_logits, 4, 2)
        return {'all_router_loss': 0.1 * router_loss}


    @force_fp32(apply_to=("model_outs",))
    def loss(self, model_outs, data, feature_maps=None):
        """Compute the detection losses and the router balancing loss.

        Args:
            model_outs: Dict holding ``"classification"``, ``"prediction"``,
                ``"quality"`` (one entry per decoder output) and
                ``"all_router_logits"``. If ``"dn_prediction"`` is present,
                the denoising outputs are read as well.
            data: Dict holding the ground truth under ``self.gt_cls_key``
                and ``self.gt_reg_key``.
            feature_maps: Unused.

        Returns:
            dict: ``"all_router_loss"``, one
            ``f"{task_prefix}_loss_cls_{idx}"`` entry per decoder output,
            whatever entries ``self.loss_reg`` returns, and the
            corresponding ``_dn_`` entries when denoising outputs exist.
        """
        # ===================== prediction losses ======================
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        quality = model_outs["quality"]
        output = {}
        output.update(self.balance_loss(model_outs["all_router_logits"]))

        for decoder_idx, (cls, reg, qt) in enumerate(
            zip(cls_scores, reg_preds, quality)
        ):
            reg = reg[..., : len(self.reg_weights)]
            cls_target, reg_target, reg_weights = self.sampler.sample(
                cls,
                reg,
                data[self.gt_cls_key],
                data[self.gt_reg_key],
            )
            reg_target = reg_target[..., : len(self.reg_weights)]
            # Rows whose regression target is all zero are treated as
            # having no assigned ground truth.
            mask = torch.logical_not(torch.all(reg_target == 0, dim=-1))

            num_pos = max(
                reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0
            )
            if self.cls_threshold_to_reg > 0:
                # Only regress instances that are also confidently scored.
                threshold = self.cls_threshold_to_reg
                mask = torch.logical_and(
                    mask, cls.max(dim=-1).values.sigmoid() > threshold
                )

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_loss = self.loss_cls(cls, cls_target, avg_factor=num_pos)

            mask = mask.reshape(-1)
            reg_weights = reg_weights * reg.new_tensor(self.reg_weights)
            reg_target = reg_target.flatten(end_dim=1)[mask]
            reg = reg.flatten(end_dim=1)[mask]
            reg_weights = reg_weights.flatten(end_dim=1)[mask]
            # NaN targets are replaced by zero.
            reg_target = torch.where(
                reg_target.isnan(), reg.new_tensor(0.0), reg_target
            )
            cls_target = cls_target[mask]
            if qt is not None:
                qt = qt.flatten(end_dim=1)[mask]

            reg_loss = self.loss_reg(
                reg,
                reg_target,
                weight=reg_weights,
                avg_factor=num_pos,
                prefix=f"{self.task_prefix}_",
                suffix=f"_{decoder_idx}",
                quality=qt,
                cls_target=cls_target,
            )

            output[f"{self.task_prefix}_loss_cls_{decoder_idx}"] = cls_loss
            output.update(reg_loss)

        if "dn_prediction" not in model_outs:
            return output

        # ===================== denoising losses ======================
        dn_cls_scores = model_outs["dn_classification"]
        dn_reg_preds = model_outs["dn_prediction"]

        (
            dn_valid_mask,
            dn_cls_target,
            dn_reg_target,
            dn_pos_mask,
            reg_weights,
            num_dn_pos,
        ) = self.prepare_for_dn_loss(model_outs)
        for decoder_idx, (cls, reg) in enumerate(
            zip(dn_cls_scores, dn_reg_preds)
        ):
            # From this decoder index on, switch to the "temp_" targets.
            if (
                "temp_dn_valid_mask" in model_outs
                and decoder_idx == self.num_single_frame_decoder
            ):
                (
                    dn_valid_mask,
                    dn_cls_target,
                    dn_reg_target,
                    dn_pos_mask,
                    reg_weights,
                    num_dn_pos,
                ) = self.prepare_for_dn_loss(model_outs, prefix="temp_")

            cls_loss = self.loss_cls(
                cls.flatten(end_dim=1)[dn_valid_mask],
                dn_cls_target,
                avg_factor=num_dn_pos,
            )
            reg_loss = self.loss_reg(
                reg.flatten(end_dim=1)[dn_valid_mask][dn_pos_mask][
                    ..., : len(self.reg_weights)
                ],
                dn_reg_target,
                avg_factor=num_dn_pos,
                weight=reg_weights,
                prefix=f"{self.task_prefix}_",
                suffix=f"_dn_{decoder_idx}",
            )
            output[f"{self.task_prefix}_loss_cls_dn_{decoder_idx}"] = cls_loss
            output.update(reg_loss)
        return output

    @force_fp32(apply_to=("model_outs",))
    def loss_map(self, model_outs, data, feature_maps=None):
        """Compute the map-branch classification and regression losses.

        One classification loss and one set of regression losses are
        produced for every decoder output found in ``model_outs``.

        Args:
            model_outs (dict): Head outputs. The entries
                ``"classification_map"``, ``"prediction_map"`` and
                ``"quality_map"`` are read and iterated in lockstep, one
                item per decoder output.
            data (dict): Batch data. Ground truth is read from
                ``data[self.gt_cls_key_map]`` and
                ``data[self.gt_reg_key_map]``.
            feature_maps: Unused; kept so the signature stays compatible
                with existing callers.

        Returns:
            dict: ``"{task_prefix_map}_loss_cls_{i}"`` for every decoder
            index ``i``, merged with the dict returned by
            ``self.loss_reg_map`` for that decoder.
        """
        cls_scores = model_outs["classification_map"]
        reg_preds = model_outs["prediction_map"]
        quality = model_outs["quality_map"]
        # Only the leading regression dims that carry a weight are supervised.
        num_reg_dims = len(self.reg_weights_map)

        output = {}
        for decoder_idx, (cls, reg, qt) in enumerate(
            zip(cls_scores, reg_preds, quality)
        ):
            reg = reg[..., :num_reg_dims]
            cls_target, reg_target, reg_weights = self.sampler_map.sample(
                cls,
                reg,
                data[self.gt_cls_key_map],
                data[self.gt_reg_key_map],
            )
            reg_target = reg_target[..., :num_reg_dims]

            # A query counts as positive unless its regression target is
            # entirely zero.
            mask = torch.logical_not(torch.all(reg_target == 0, dim=-1))
            # Positive count averaged via reduce_mean, floored at 1 so it
            # is always a safe divisor.
            num_pos = max(
                reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0
            )
            if self.cls_threshold_to_reg > 0:
                # Regress only the positives whose best class score clears
                # the threshold; num_pos intentionally stays unfiltered.
                threshold = self.cls_threshold_to_reg
                mask = torch.logical_and(
                    mask, cls.max(dim=-1).values.sigmoid() > threshold
                )

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_loss = self.loss_cls_map(cls, cls_target, avg_factor=num_pos)

            mask = mask.reshape(-1)
            reg_weights = reg_weights * reg.new_tensor(self.reg_weights_map)
            reg_target = reg_target.flatten(end_dim=1)[mask]
            reg = reg.flatten(end_dim=1)[mask]
            reg_weights = reg_weights.flatten(end_dim=1)[mask]
            # NaN entries in the target are replaced by zero before the loss.
            reg_target = torch.where(
                reg_target.isnan(), reg.new_tensor(0.0), reg_target
            )
            cls_target = cls_target[mask]
            if qt is not None:
                qt = qt.flatten(end_dim=1)[mask]

            reg_loss = self.loss_reg_map(
                reg,
                reg_target,
                weight=reg_weights,
                avg_factor=num_pos,
                prefix=f"{self.task_prefix_map}_",
                suffix=f"_{decoder_idx}",
                quality=qt,
                cls_target=cls_target,
            )

            output[f"{self.task_prefix_map}_loss_cls_{decoder_idx}"] = cls_loss
            output.update(reg_loss)

        # No denoising losses are computed for the map branch, so the
        # prediction losses are returned whether or not ``model_outs``
        # carries "dn_prediction" entries.
        return output

    def prepare_for_dn_loss(self, model_outs, prefix=""):
        """Collect the masks, targets and weights needed by the denoising losses.

        Args:
            model_outs (dict): Head outputs. The entries
                ``f"{prefix}dn_valid_mask"``, ``f"{prefix}dn_cls_target"``
                and ``f"{prefix}dn_reg_target"`` are read; each has its
                first two dimensions merged before use.
            prefix (str): Key prefix selecting which set of denoising
                entries to read. Defaults to ``""``.

        Returns:
            tuple: ``(dn_valid_mask, dn_cls_target, dn_reg_target,
            dn_pos_mask, reg_weights, num_dn_pos)`` where the targets are
            already filtered by the valid mask (and the regression target
            additionally by the positive mask), ``reg_weights`` repeats
            ``self.reg_weights`` once per kept regression target, and
            ``num_dn_pos`` is the ``reduce_mean``-ed count of valid
            entries, floored at 1.0.
        """
        valid = model_outs[f"{prefix}dn_valid_mask"].flatten(end_dim=1)

        flat_cls_target = model_outs[f"{prefix}dn_cls_target"].flatten(end_dim=1)
        cls_target = flat_cls_target[valid]

        flat_reg_target = model_outs[f"{prefix}dn_reg_target"].flatten(end_dim=1)
        reg_target = flat_reg_target[valid][..., : len(self.reg_weights)]

        # Entries with a non-negative class target are the positives.
        positive = cls_target >= 0
        reg_target = reg_target[positive]

        # One row of per-dimension weights for each kept regression target.
        weight_row = reg_target.new_tensor(self.reg_weights)[None]
        weights = weight_row.tile(reg_target.shape[0], 1)

        valid_count = torch.sum(valid).to(dtype=weights.dtype)
        num_dn_pos = max(reduce_mean(valid_count), 1.0)

        return (valid, cls_target, reg_target, positive, weights, num_dn_pos)

    @force_fp32(apply_to=("model_outs",))
    def post_process(self, model_outs, output_idx=-1):
        """Decode the detection outputs with ``self.decoder``.

        Args:
            model_outs (dict): Head outputs. ``"classification"`` and
                ``"prediction"`` are required; ``"instance_id"`` and
                ``"quality"`` are optional and passed as ``None`` when
                absent.
            output_idx (int): Forwarded unchanged to the decoder.
                Defaults to -1.

        Returns:
            Whatever ``self.decoder.decode`` returns.
        """
        return self.decoder.decode(
            model_outs["classification"],
            model_outs["prediction"],
            model_outs.get("instance_id"),
            model_outs.get("quality"),
            output_idx=output_idx,
        )

    @force_fp32(apply_to=("model_outs",))
    def post_process_map(self, model_outs, output_idx=-1):
        """Decode the map outputs with ``self.decoder_map``.

        Args:
            model_outs (dict): Head outputs. ``"classification_map"`` and
                ``"prediction_map"`` are required; ``"instance_id_map"``
                and ``"quality_map"`` are optional and passed as ``None``
                when absent.
            output_idx (int): Forwarded unchanged to the decoder.
                Defaults to -1.

        Returns:
            Whatever ``self.decoder_map.decode`` returns.
        """
        return self.decoder_map.decode(
            model_outs["classification_map"],
            model_outs["prediction_map"],
            model_outs.get("instance_id_map"),
            model_outs.get("quality_map"),
            output_idx=output_idx,
        )
