# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------
# DINO
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR Transformer class.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
import math
import random
from typing import Optional

import torch
from torch import nn
from torch import Tensor

from .ops.modules import MSDeformAttn
from scene.transformer_decoder.encoder_cross_attn import multi_view_cross_attn, CrossAttnBlock
from utils.dino_utils import _get_activation_fn
from utils.dino_utils import gen_sineembed_for_3Dposition
from utils.dino_utils import inverse_sigmoid
from utils.dino_utils import MLP
from utils.general_utils import quaternion_raw_multiply
from utils.geometry_utils import ReferencePointProjection
from utils.group_attention import down_sample_query

class DINODecoder(nn.Module):
    """ This is the Cross-Attention Detector module that performs object detection.

    In this file it projects multi-level image features to the transformer width with
    1x1 conv + GroupNorm layers, runs ``transformer`` and returns only the Gaussian
    predictions (third output) of the transformer.
    """

    def __init__(self, cfg, transformer, num_queries, backbone_out_channels,
                 query_dim=4,
                 num_feature_levels=4,
                 nheads=8,
                 ):
        """ Initializes the model.
        Parameters:
            cfg: experiment config; stored as ``self.cfg``.
            transformer: torch module of the transformer architecture. Its ``d_model``
                         attribute sets the output width of every input projection.
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         Conditional DETR can detect in a single image.
            backbone_out_channels: channel count of each backbone output level.
            query_dim: query dimensionality; only 4 is accepted.
            num_feature_levels: number of feature levels handed to the transformer. Levels past
                                ``len(backbone_out_channels)`` get their own extra projection.
            nheads: number of attention heads (stored only).
        """
        super().__init__()
        self.cfg = cfg
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads

        # setting query dim
        self.query_dim = query_dim
        assert query_dim == 4

        # prepare input projection layers (1x1 conv + GroupNorm(32) per feature level)
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone_out_channels)
            self.num_backbone_outs = num_backbone_outs
            input_proj_list = []
            bb_kernel = 1
            bb_stride = 1

            fused_kernel = 1
            fused_stride = 1

            for in_channels in backbone_out_channels:
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=bb_kernel, stride=bb_stride),
                    nn.GroupNorm(32, hidden_dim),
                ))

            # Extra (non-backbone) levels take the channel count of the last backbone output.
            # NOTE(review): forward() reshapes these levels to features[0]'s channel count, so
            # this presumably assumes all backbone levels share one width — confirm.
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=fused_kernel, stride=fused_stride),
                    nn.GroupNorm(32, hidden_dim),
                ))
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            # Single-level case. forward() iterates over ``self.num_backbone_outs`` levels, so it
            # must be defined here too. The only projection is sized for the last backbone output
            # (assumes features[0] carries that many channels — TODO confirm with the caller).
            self.num_backbone_outs = 1
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone_out_channels[-1], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init every projection conv weight and zero its bias."""
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, 
                features: Tensor, 
                h: int, 
                w: int, 
                intrinsics: Tensor, 
                const_offset: Tensor = None,
                source_cameras_view_to_world: Tensor = None,
                cam_emb: Tensor = None,
                point_cloud_init: Tensor = None,
                stage1_network_outputs: Tensor = None):
        """Project the features, run the transformer and return its Gaussian predictions.

        Args:
            features: list of per-level feature tensors. The first ``num_backbone_outs``
                entries are 4-D ``(B, C, H, W)`` maps. Any further entries (up to
                ``num_feature_levels``) are 3-D and are permuted/reshaped to
                ``(B, C0, h, w)`` with ``C0 = features[0].shape[1]``; presumably they
                arrive as ``(B, h*w, C0)`` — confirm with the caller.
                NOTE: those extra entries are overwritten in the caller's list.
            h, w: spatial size used to un-flatten the extra (non-backbone) levels.
            intrinsics, const_offset, source_cameras_view_to_world, cam_emb,
            point_cloud_init, stage1_network_outputs: forwarded unchanged to
                ``self.transformer``.

        Returns:
            The third output of ``self.transformer`` (the Gaussian predictions); its
            hidden states and reference points are discarded.
        """
        B, bb_dim, _, _ = features[0].shape
        srcs = []
        masks = []
        poss = None

        # Backbone levels: project each 4-D map; masks are all False (no padding).
        for l in range(self.num_backbone_outs):
            src = self.input_proj[l](features[l])
            srcs.append(src)
            masks.append(torch.zeros_like(src[:, 0, :, :]).bool())

        # Extra levels: token sequences turned back into (B, bb_dim, h, w) maps first.
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)

            for l in range(_len_srcs, self.num_feature_levels):
                # Kept as an in-place write to the caller's list in case callers rely on it.
                features[l] = features[l].permute(0, 2, 1).reshape(B, bb_dim, h, w)
                src = self.input_proj[l](features[l])
                mask = torch.zeros_like(src[:, 0, :, :]).bool()
                srcs.append(src)
                masks.append(mask)

        # ``poss`` (always None) is passed positionally as the transformer's third argument
        # (``attn_mask`` for DeformableTransformer).
        _hs, _reference, gs_preds = self.transformer(srcs,
                                                     masks,
                                                     poss, 
                                                     intrinsics=intrinsics, 
                                                     const_offset=const_offset,
                                                     source_cameras_view_to_world=source_cameras_view_to_world,
                                                     cam_emb=cam_emb,
                                                     point_cloud_init=point_cloud_init,
                                                     stage1_network_outputs=stage1_network_outputs)

        return gs_preds


class DeformableTransformer(nn.Module):
    """Multi-view deformable transformer used to refine 3D Gaussian parameters.

    Holds one ``CrossAttnBlock`` per feature level for cross-view feature mixing, a
    ``TransformerDecoder`` built from ``DeformableTransformerDecoderLayer``s, and a learned
    query embedding ``tgt_embed`` of shape ``(stage1_queries, d_model)``.
    """

    def __init__(self, cfg, d_model=256, nhead=8,
                 num_queries=300,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.0,
                 activation="relu", query_dim=4,
                 modulate_hw_attn=False,
                 deformable_decoder=False,
                 num_feature_levels=1,
                 dec_n_points=4,
                 # evo of #anchors
                 dec_layer_number=None,
                 rm_self_attn_layers=None,
                 key_aware_type=None,
                 # for detach
                 rm_detach=None,
                 decoder_sa_type='ca',
                 module_seq=['sa', 'ca', 'ffn'],
                 out_dim=None,
                 stage1_queries=0,
                 ):
        """Build the cross-view attention blocks, the decoder stack and the query embeddings.

        Notes:
            - ``num_encoder_layers`` is only stored; no encoder is built in this class.
            - Only ``deformable_decoder=True`` is supported; otherwise ``NotImplementedError``.
            - NOTE(review): the default ``decoder_sa_type='ca'`` is rejected by the assert below,
              so callers must pass one of 'sa' / 'ca_label' / 'ca_content'
              (``build_deformable_transformer`` always passes a value from the config).
            - ``out_dim`` must be a sequence: the decoder heads are sized with ``sum(out_dim)``.
            - ``module_seq`` is only iterated downstream; the list default is never mutated here.
        """
        super().__init__()

        self.cfg = cfg
        self.num_feature_levels = num_feature_levels
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.deformable_decoder = deformable_decoder
        self.num_queries = num_queries
        assert query_dim == 4

        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
        # One single-head cross-view attention block per feature level; used in forward()
        # only when cfg.data.input_images > 1.
        self.cross_attn_fn = nn.ModuleList([CrossAttnBlock(d_model=d_model, 
                                                                    nhead=1
                                                    ) for i in range(num_feature_levels)])
        # choose decoder layer type
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(cfg, d_model, dim_feedforward,
                                                              dropout, activation,
                                                              num_feature_levels, nhead, dec_n_points,
                                                              key_aware_type=key_aware_type,
                                                              decoder_sa_type=decoder_sa_type,
                                                              module_seq=module_seq)

        else:
            raise NotImplementedError

        decoder_norm = None
        self.decoder = TransformerDecoder(cfg, decoder_layer, num_decoder_layers,
                                          d_model=d_model, query_dim=query_dim,
                                          modulate_hw_attn=modulate_hw_attn,
                                          num_feature_levels=num_feature_levels,
                                          deformable_decoder=deformable_decoder,
                                          dec_layer_number=dec_layer_number,
                                          out_dim=out_dim,
                                          dec_n_points=dec_n_points
                                          )

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries  # useful for single stage model only

        # level_embed is only defined when num_feature_levels > 1, and then it is None;
        # _reset_parameters() guards on both conditions before touching it.
        if num_feature_levels > 1:
            self.level_embed = None


        # Learned initial query content, one row per stage-1 query.
        self.tgt_embed = nn.Embedding(stage1_queries, d_model)
        nn.init.normal_(self.tgt_embed.weight.data)

        # Creates self.refpoint_embed; forward() currently does not read it
        # (it passes refpoints=None and relies on stage1_network_outputs instead).
        self.init_ref_points(num_queries)  # init self.refpoint_embed

        # evolution of anchors
        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert dec_layer_number[0] == num_queries, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"

        # Parameter re-initialization is currently disabled.
        # self._reset_parameters(except_dim=sum(out_dim))

        self.rm_self_attn_layers = rm_self_attn_layers
        if rm_self_attn_layers is not None:
            print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
        self.decoder.rm_detach = rm_detach

    def _reset_parameters(self, except_dim):
        """Xavier-init every >1-D parameter whose first dim differs from ``except_dim``,
        re-init MSDeformAttn modules, and normal-init ``level_embed`` when it exists.

        Not called from ``__init__`` at present (the call there is commented out).
        """
        for p in self.parameters():
            if p.dim() > 1 and p.shape[0] != except_dim:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        """Return the fraction of non-padded columns/rows per sample as ``[..., (w_ratio, h_ratio)]``.

        ``mask`` is ``(batch, H, W)`` with True marking padding; only the first column and
        the first row are inspected.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def init_ref_points(self, use_num_queries):
        """Create ``self.refpoint_embed``: one learnable 3-vector per query."""
        num_ref_point = 3
        self.refpoint_embed = nn.Embedding(use_num_queries, num_ref_point)
        

    def forward(self, 
                srcs, 
                masks, 
                attn_mask=None, 
                intrinsics=None, 
                const_offset=None,
                source_cameras_view_to_world=None,
                cam_emb=None,
                point_cloud_init=None,
                stage1_network_outputs=None):
        """
        Input:
            - srcs: List of multi-level features [bs, ci, hi, wi]. ``bs`` appears to count
              every view of every scene (it is divided by the number of views below).
            - masks: List of multi-level padding masks [bs, hi, wi]
            - attn_mask: self-attention mask forwarded to the decoder as ``tgt_mask``
            - intrinsics, const_offset, cam_emb, point_cloud_init: forwarded to the decoder
            - source_cameras_view_to_world: camera-to-world matrices; dim 1 is the number of
              views. Required (its ``.shape`` is read unconditionally).
            - stage1_network_outputs: initial Gaussian parameters, passed to the decoder
              as ``init_gs``

        Returns:
            (hs, references, gs_preds) exactly as produced by ``self.decoder``.
        """

        N_views = source_cameras_view_to_world.shape[1]

        ### do cross attention on image features in different views
        if self.cfg.data.input_images > 1:
            num_views = self.cfg.data.input_images
            B = int(srcs[0].shape[0] / num_views)
            attn_num_splits = self.cfg.model.attn_num_splits
            cross_features = []
            for i in range(len(srcs)):
                # Flatten each level to tokens (Bn, H*W, D), mix across views, restore the map.
                Bn, D, H, W = srcs[i].size() 
                src = multi_view_cross_attn(self.cross_attn_fn[i], srcs[i].view(Bn, D, H * W).permute(0, 2, 1), num_views, attn_num_splits)
                src = src.view(Bn, H, W, D).permute(0, 3, 1, 2)
                cross_features.append(src)
            srcs = cross_features
        

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        spatial_shapes = []
        for lvl, (src, mask) in enumerate(zip(srcs, masks)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)                # bs, hw, c
            mask = mask.flatten(1)                              # bs, hw

            src_flatten.append(src)
            mask_flatten.append(mask)

        src_flatten = torch.cat(src_flatten, 1)    # bs, \sum{hxw}, c
        mask_flatten = torch.cat(mask_flatten, 1)   # bs, \sum{hxw}
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Offset of each level inside the flattened token axis.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 

        ### do cross attention on image features in different views

        memory = src_flatten
        # Query content shared across the batch: (bs // N_views, nq, d_model).
        tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs//N_views, 1).transpose(0, 1)  # (bs // N_views, nq, d_model)
        
        init_gs = stage1_network_outputs
        refpoint_embed_ = None
        refpoint_embed, tgt = refpoint_embed_, tgt_


        #########################################################
        # End preparing tgt
        # - tgt: bs // N_views, NQ, d_model
        # - refpoint_embed: None (init_gs supplies the initial Gaussians)
        #########################################################

        #########################################################
        # Begin Decoder
        #########################################################
        hs, references, gs_preds = self.decoder(
            tgt=tgt.transpose(0, 1),  # (nq, bs, d_model)
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1) if refpoint_embed is not None else None,
            init_gs=init_gs,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios, tgt_mask=attn_mask,
            intrinsics=intrinsics,
            const_offset=const_offset,
            source_cameras_view_to_world=source_cameras_view_to_world,
            cam_emb=cam_emb,
            point_cloud_init=point_cloud_init)
        #########################################################
        # End Decoder
        # hs: n_dec, bs, nq, d_model
        # references: n_dec+1, bs, nq, query_dim
        #########################################################


        return hs, references, gs_preds
       


class TransformerDecoder(nn.Module):
    """Stack of decoder layers that iteratively refines a tensor of 3D Gaussian parameters.

    After every layer two zero-initialised 1x1 conv heads decode the queries into a residual
    update: ``outs_pos`` yields 3 channels and ``outs_others`` yields ``sum(out_dim) - 3``.
    Channels 0-6 and 11+ are added to the running Gaussians; channels 7-10 are composed with
    ``quaternion_raw_multiply`` instead.
    """

    def __init__(self, cfg, decoder_layer, num_layers,
                 d_model=256, query_dim=4,
                 modulate_hw_attn=False,
                 num_feature_levels=1,
                 deformable_decoder=False,
                 dec_layer_number=None,  # number of queries each layer in decoder
                 dec_layer_dropout_prob=None,
                 out_dim=None,
                 dec_n_points=4

                 ):
        """Clone ``decoder_layer`` ``num_layers`` times and build one pair of output heads per layer.

        ``out_dim`` must be a sequence; the heads emit ``3`` and ``sum(out_dim) - 3`` channels.
        ``dec_layer_dropout_prob`` (optional) gives, per layer, the probability of skipping it.
        """
        super().__init__()
        if num_layers > 0:
            self.layers = _get_clones(decoder_layer, num_layers, layer_share=False)
        else:
            self.layers = []
        self.cfg = cfg
        self.num_layers = num_layers
        self.hidden_dim = d_model
        self.dec_n_points = dec_n_points
        self.query_dim = query_dim
        assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
        self.num_feature_levels = num_feature_levels

        self.ref3d_proj = ReferencePointProjection(cfg)
        # Maps the 3 * 128-dim sine embedding of a 3D position to d_model.
        self.ref_point_head = MLP(128 * 3, d_model, d_model, 2)
            
        if not deformable_decoder:
            self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
        else:
            self.query_pos_sine_scale = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
        self.deformable_decoder = deformable_decoder

        if not deformable_decoder and modulate_hw_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
        else:
            self.ref_anchor_head = None

        self.dec_layer_number = dec_layer_number
        if dec_layer_number is not None:
            assert isinstance(dec_layer_number, list)
            assert len(dec_layer_number) == num_layers

        self.dec_layer_dropout_prob = dec_layer_dropout_prob
        if dec_layer_dropout_prob is not None:
            assert isinstance(dec_layer_dropout_prob, list)
            assert len(dec_layer_dropout_prob) == num_layers
            for i in dec_layer_dropout_prob:
                assert 0.0 <= i <= 1.0

        self.rm_detach = None
        
        out_layers_pos = []
        out_layers_others = []
        for lay in range(self.num_layers):
            out_pos, out_others = self.init_out_head_split_zero_init(out_dim)
            
            out_layers_pos.append(out_pos)
            out_layers_others.append(out_others)

        self.outs_pos = nn.ModuleList(out_layers_pos)
        self.outs_others = nn.ModuleList(out_layers_others)

    def out_head(self, dec_out, out_layer):
        """Apply a 1x1 conv head to queries laid out on a square grid.

        Args:
            dec_out: ``(B, n_gaussian, dec_dim)``; ``n_gaussian`` must be a perfect square.
            out_layer: conv module taking ``dec_dim`` input channels.

        Returns:
            ``out_layer`` applied to ``dec_out`` reshaped to ``(B, dec_dim, sqrt_N, sqrt_N)``.
        """
        
        B, n_gaussian, dec_dim = dec_out.shape
        sqrt_N = int(math.sqrt(n_gaussian))
        assert sqrt_N == math.sqrt(n_gaussian), "number of 3D gaussians should be perfect square values"
        gs_pred = out_layer(dec_out.permute(0, 2, 1).reshape(B, dec_dim, sqrt_N, sqrt_N))
        
        return gs_pred


    def init_out_head_split_zero_init(self, out_dim):
        """Create the position head (3 channels) and the head for the remaining
        ``sum(out_dim) - 3`` channels, both with zero weights and biases so that each
        layer initially predicts a zero update."""
        
        out_pos = nn.Conv2d(in_channels=self.hidden_dim,
                        out_channels=3,
                        kernel_size=1)
        out_others = nn.Conv2d(in_channels=self.hidden_dim,
                        out_channels=sum(out_dim) - 3,
                        kernel_size=1)
        out_pos.weight.data.zero_()
        out_pos.bias.data.zero_()
        out_others.weight.data.zero_()
        out_others.bias.data.zero_()
        
        return out_pos, out_others


    def forward(self, tgt, 
                memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,  # num_queries, bs, 2
                init_gs: Optional[Tensor] = None,
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None,
                intrinsics: Optional[Tensor] = None,
                const_offset: Optional[Tensor] = None,
                source_cameras_view_to_world: Optional[Tensor] = None,
                cam_emb: Optional[Tensor] = None,
                point_cloud_init: Optional[Tensor] = None
                ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: hw, bs * N_views, d_model (reshaped below to hw, bs, N_views, d_model;
              assumes views are the inner factor of the batch axis — confirm)
            - refpoints_unsigmoid: nq, bs, gaussian_channel (used only when init_gs is None)
            - init_gs: initial Gaussian tensor; when given, nq is read from
              cfg.model.dino_decoder.num_queries and bs from init_gs.shape[0]
            - valid_ratios/spatial_shapes: bs, nlevel, 2
            - source_cameras_view_to_world: bs, N_views, 4, 4 (required)

        Returns:
            [per-layer query outputs, reference points (initial + one per layer),
             per-layer Gaussian tensors]. Query outputs and reference points are
            transposed from (nq, bs, ...) to (bs, nq, ...).
        """
        output = tgt
        itm_gs_preds = []
        intermediate = []

        N_views = source_cameras_view_to_world.shape[1] ## bs, N_views, 4, 4

        # Express every camera relative to the first view's camera frame.
        c2w_view1 = source_cameras_view_to_world[:, 0]
        w2c_view1 = torch.inverse(c2w_view1.float()).to(source_cameras_view_to_world.dtype)

        source_cameras_view_to_view1 = torch.matmul(w2c_view1[:, None], source_cameras_view_to_world)
        
        if init_gs is None:
            nq, bs, gaussian_channel = refpoints_unsigmoid.shape
        else:
            nq = self.cfg.model.dino_decoder.num_queries
            bs = init_gs.shape[0]
        sqrt_nq = int(math.sqrt(nq))
        if init_gs is None:
            gs_init = refpoints_unsigmoid.permute(1, 2, 0).reshape(bs, gaussian_channel, sqrt_nq, sqrt_nq) # (B, C, sqrt_nq, sqrt_nq)
        else:
            gs_init = init_gs

        
        ### change gs_output from d,xyz to xyz
        if point_cloud_init is not None:
            gs_outputs = torch.cat([point_cloud_init, gs_init], dim=1)
        else:
            gs_outputs = gs_init

        reference_points, pos_3d, pos_world = self.ref3d_proj.project_3d_gaussian_to_uv(gs_outputs, intrinsics, const_offset, source_cameras_view_to_view1)
        memory_views = memory.reshape(memory.shape[0], bs, N_views, memory.shape[-1])

        reference_points = reference_points.permute(1, 0, 2)  # (nq, bs, 2)
        # Record the initial reference points AFTER the permute so that every entry of
        # ref_points shares the layout of the per-layer updates appended below; the
        # return statement transposes all entries uniformly.
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):
            if point_cloud_init is not None:
                # pos_world = gs_outputs[:, :3, :, :]
                pos_world = self.ref3d_proj.flatten_vector(pos_world)
            else:
                # Channel 0 is read as depth and channels 1:4 as offset.
                depth, offset = gs_outputs[:, 0, :, :].unsqueeze(1), gs_outputs[:, 1:4, :, :]
                pos_world = self.ref3d_proj.get_pos_from_network_output(depth, offset, const_offset=const_offset)
                pos_world = self.ref3d_proj.flatten_vector(pos_world)

            query_sine_embed = gen_sineembed_for_3Dposition(pos_world).permute(1, 0, 2) #(nq, bs, 384)

            raw_query_pos = self.ref_point_head(query_sine_embed)
            pos_scale = 1
            query_pos = pos_scale * raw_query_pos

            # NOTE(review): reference_points_input is only defined on the deformable path;
            # a non-deformable decoder would hit a NameError below.
            if self.deformable_decoder: # True
                reference_points_input = reference_points[:, :, None] * valid_ratios[None]  # (nq, bs, nlevel, 2)
                    
            # random drop some layers if needed
            dropflag = False
            if self.dec_layer_dropout_prob is not None:
                prob = random.random()
                if prob < self.dec_layer_dropout_prob[layer_id]:
                    dropflag = True

            if not dropflag:
                output = layer(
                    tgt=output,
                    tgt_query_pos=query_pos,
                    tgt_reference_points=reference_points_input,  # (nq, bs, n_level, 2)

                    memory=memory_views,
                    memory_key_padding_mask=memory_key_padding_mask,
                    memory_level_start_index=level_start_index,
                    memory_spatial_shapes=spatial_shapes,

                    self_attn_mask=tgt_mask,
                    cross_attn_mask=memory_mask,

                    pos=pos_world,
                    down_sample_rate=self.cfg.model.dino_decoder.down_sample_rate,
                    cam_emb=cam_emb
                )

            # iter update: decode residual Gaussian updates from the layer output
            dec_out = output.transpose(0, 1) #(bs, nq, d_model)
            # NOTE(review): ``pos_world`` is reused for the position *delta* predicted here.
            # When point_cloud_init is given, the next layer's positional embedding is built
            # from this delta (top of the loop) rather than from the updated Gaussian
            # positions — confirm this is intended.
            pos_world = self.out_head(dec_out, self.outs_pos[layer_id]) #(bs, 3, sqrt_nq, sqrt_nq)

            # get other parameters togther to get query
            dec_out = self.out_head(dec_out, self.outs_others[layer_id]) # (bs, sum(out_dim) - 3, sqrt_nq, sqrt_nq)

            # update query to next layer and delta Gaussian params
            delta_gs = torch.cat([pos_world, dec_out], dim=1)

            # Channels 0-6 and 11+ are updated additively; channels 7-10 hold a quaternion,
            # which is updated by composing the predicted delta rotation with the current one.
            gs_outputs_before_quater = delta_gs[:, :7, :, :] + gs_outputs[:, :7, :, :]
            gs_outputs_after_quater = delta_gs[:, 11:, :, :] + gs_outputs[:, 11:, :, :]
            delta_quater = delta_gs[:, 7:11, :, :].flatten(2).permute(0, 2, 1)
            gs_out_quater = gs_outputs[:, 7:11, :, :].flatten(2).permute(0, 2, 1)
            
            new_quaternion = quaternion_raw_multiply(delta_quater, gs_out_quater)
            gs_outputs_quater = new_quaternion.permute(0, 2, 1).reshape(bs, 4, sqrt_nq, sqrt_nq)
            gs_outputs = torch.cat([gs_outputs_before_quater, gs_outputs_quater, gs_outputs_after_quater], dim=1)

            
            new_reference_points, _, _ = self.ref3d_proj.project_3d_gaussian_to_uv(gs_outputs, intrinsics, const_offset, source_cameras_view_to_view1)
            new_reference_points = new_reference_points.permute(1, 0, 2)  # (nq, bs, 2)
            
            # Unless 'dec' is listed in rm_detach, stop gradients through the reference points.
            if self.rm_detach and 'dec' in self.rm_detach:
                reference_points = new_reference_points
            else:
                reference_points = new_reference_points.detach()

            ref_points.append(new_reference_points)
            
            itm_gs_preds.append(gs_outputs)
            intermediate.append(output)


        return [
            [itm_out.transpose(0, 1) for itm_out in intermediate],
            [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
            itm_gs_preds,
        ]


class DeformableTransformerDecoderLayer(nn.Module):
    """Decoder layer combining self-attention, deformable cross-attention and an FFN.

    The sub-modules run in the order listed in ``module_seq`` (entries 'sa', 'ca', 'ffn').
    When ``cfg.data.mod_camera_dec`` is set, ``memory`` is passed through an adaLN camera
    modulation (``self.modln``) driven by ``cam_emb``.
    """
    def __init__(self, cfg, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4,
                 key_aware_type=None,
                 decoder_sa_type='ca',
                 module_seq=['sa', 'ca', 'ffn'],
                 eps=1e-6
                 ):
        """Build the attention, FFN and camera-modulation sub-modules.

        NOTE(review): the default ``decoder_sa_type='ca'`` fails the assert at the end of
        this method; callers must pass 'sa', 'ca_label' or 'ca_content'.
        """
        super().__init__()
        self.cfg = cfg
        self.module_seq = module_seq
        # cam modulation: one ModLN for the memory, one handed to the cross-attention
        self.modln = ModLN(inner_dim=d_model, mod_dim=self.cfg.model.dino_decoder.camera_embed_dim, eps=eps)
        self.modln_query = ModLN(inner_dim=d_model, mod_dim=self.cfg.model.dino_decoder.camera_embed_dim, eps=eps)


        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points, n_views=cfg.data.input_images)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.key_aware_type = key_aware_type
        # NOTE(review): stays None, so key_aware_type='proj_mean' would fail in forward_ca.
        self.key_aware_proj = None
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']

    def rm_self_attn_modules(self):
        """Drop the self-attention branch; forward_sa then returns ``tgt`` unchanged."""
        self.self_attn = None
        self.dropout2 = None
        self.norm2 = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Post-norm feed-forward block with a residual connection."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_sa(self,
                   # for tgt
                   tgt: Optional[Tensor],  # nq, bs, d_model
                   tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                   tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                   # for memory
                   memory: Optional[Tensor] = None,  # hw, bs, d_model
                   memory_key_padding_mask: Optional[Tensor] = None,
                   memory_level_start_index: Optional[Tensor] = None,  # num_levels
                   memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2

                   # sa
                   self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                   cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
                   pos: Optional[Tensor] = None,
                   down_sample_rate: Optional[Tensor] = None
                   ):
        """Post-norm self-attention over the queries.

        Queries and keys get ``tgt_query_pos`` added. When ``down_sample_rate`` is given,
        is below 1 and ``pos`` is provided, keys and values are reduced with
        ``down_sample_query``; otherwise every query is attended to. Returns ``tgt``
        unchanged when the self-attention branch was removed.
        """
        # self attention
        if self.self_attn is not None:
            q = k = self.with_pos_embed(tgt, tgt_query_pos)
            
            # Guard the default ``down_sample_rate=None``: ``None < 1`` raises TypeError.
            if down_sample_rate is not None and down_sample_rate < 1 and pos is not None:
                k = down_sample_query(k, pos, down_sample_rate)
                v = down_sample_query(tgt, pos, down_sample_rate)
            else:
                v = tgt
            tgt2 = self.self_attn(q, k, v, attn_mask=self_attn_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)

        return tgt

    def forward_ca(self,
                   # for tgt
                   tgt: Optional[Tensor],  # nq, bs, d_model
                   tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                   tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                   # for memory
                   memory: Optional[Tensor] = None,  # hw, bs, d_model
                   memory_key_padding_mask: Optional[Tensor] = None,
                   memory_level_start_index: Optional[Tensor] = None,  # num_levels
                   memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2

                   # sa
                   self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                   cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention

                   cam_emb: Optional[Tensor] = None,
                   modln_query: Optional[nn.Module] = None
                   ):
        """Post-norm deformable cross-attention from the queries into ``memory``.

        ``memory`` is passed to ``MSDeformAttn`` without transposing; ``cam_emb`` and
        ``modln_query`` are forwarded to it as ``cam_emb`` / ``modln_fn``.
        """
        # cross attention
        if self.key_aware_type is not None:

            if self.key_aware_type == 'mean':
                tgt = tgt + memory.mean(0, keepdim=True)
            elif self.key_aware_type == 'proj_mean':
                tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
            else:
                raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
                               tgt_reference_points.transpose(0, 1).contiguous(),
                               memory,
                               memory_spatial_shapes, 
                               memory_level_start_index, 
                               memory_key_padding_mask,
                               cam_emb=cam_emb,
                               modln_fn=modln_query).transpose(0, 1)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        return tgt

    def forward(self,
                # for tgt
                tgt: Optional[Tensor],  # nq, bs, d_model
                tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
                tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4

                # for memory
                memory: Optional[Tensor] = None,  # hw, bs, d_model
                memory_key_padding_mask: Optional[Tensor] = None,
                memory_level_start_index: Optional[Tensor] = None,  # num_levels
                memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2

                # sa
                self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
                cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention,
                pos: Optional[Tensor] = None,
                down_sample_rate: Optional[Tensor] = None,
                cam_emb: Optional[Tensor] = None
                ):
        """Run the sub-modules named in ``module_seq`` in order and return the updated queries.

        Raises:
            ValueError: if ``module_seq`` contains an unknown entry.
        """
        for funcname in self.module_seq:
            # NOTE(review): this modulation runs once per entry of module_seq (including
            # 'sa' and 'ffn', which never read memory), so memory is re-normalized and
            # re-modulated several times per layer. Left unchanged because altering it would
            # change model outputs; confirm whether one application per layer was intended.
            if self.cfg.data.mod_camera_dec:
                memory = self.modln(memory, cam_emb)
            if funcname == 'ffn':
                tgt = self.forward_ffn(tgt)

            elif funcname == 'ca':
                tgt = self.forward_ca(tgt, tgt_query_pos,
                                      tgt_reference_points,
                                      memory, memory_key_padding_mask, memory_level_start_index,
                                      memory_spatial_shapes, self_attn_mask, cross_attn_mask, cam_emb, self.modln_query)
            elif funcname == 'sa':
                tgt = self.forward_sa(tgt, tgt_query_pos,
                                      tgt_reference_points,
                                      memory, memory_key_padding_mask, memory_level_start_index,
                                      memory_spatial_shapes, self_attn_mask, cross_attn_mask, pos, down_sample_rate)
            
            else:
                raise ValueError('unknown funcname {}'.format(funcname))

        return tgt

class ModLN(nn.Module):
    """
    Adaptive LayerNorm (adaLN) modulation.

    A conditioning vector ``mod`` is mapped through SiLU -> Linear to a per-channel
    ``shift`` and ``scale``. The input is layer-normalised and then transformed as
    ``norm(x) * (1 + scale) + shift``, with shift/scale broadcast over a new leading axis.

    References:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """
    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        # A single projection produces [shift | scale], each of width inner_dim.
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(mod_dim, 2 * inner_dim))

    @staticmethod
    def modulate(x, shift, scale):
        """Return ``x * (1 + scale) + shift`` with shift/scale broadcast over dim 0 of ``x``."""
        gain = 1 + scale[None]
        bias = shift[None]
        return x * gain + bias

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        """Layer-normalise ``x`` and modulate it with shift/scale predicted from ``mod``."""
        shift_and_scale = self.mlp(mod)
        shift, scale = torch.chunk(shift_and_scale, 2, dim=-1)
        normed = self.norm(x)
        return self.modulate(normed, shift, scale)

def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_deformable_transformer(cfg, out_dim):
    """Construct the ``DeformableTransformer`` configured by ``cfg.model.dino_decoder``.

    Fixed choices: 4 feature levels, the deformable decoder, ``modulate_hw_attn=True`` and
    ``rm_detach=None``. ``stage1_queries`` is tied to ``num_queries``.

    Args:
        cfg: experiment config; ``cfg.model.dino_decoder`` supplies the hyper-parameters.
        out_dim: sequence whose sum is the channel count of the decoder output heads.

    Returns:
        A newly constructed ``DeformableTransformer``.
    """
    dino_cfg = cfg.model.dino_decoder
    return DeformableTransformer(
        cfg,
        d_model=dino_cfg.hidden_dim,
        dropout=dino_cfg.dropout,
        nhead=dino_cfg.nheads,
        num_queries=dino_cfg.num_queries,
        dim_feedforward=dino_cfg.dim_feedforward,
        num_decoder_layers=dino_cfg.num_decoder_layers,
        query_dim=dino_cfg.query_dim,
        activation=dino_cfg.transformer_activation,
        modulate_hw_attn=True,
        deformable_decoder=True,
        num_feature_levels=4,
        dec_n_points=dino_cfg.dec_n_points,
        rm_detach=None,
        decoder_sa_type=dino_cfg.decoder_sa_type,
        module_seq=dino_cfg.decoder_module_seq,
        out_dim=out_dim,
        stage1_queries=dino_cfg.num_queries,
    )
