import copy
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
import math
from models.net_utils import MLP, gen_sineembed_for_position, inverse_sigmoid, box_cxcywh_to_xyxy,generalized_box_iou,greater_than_indices

from .position_encoding import SeqEmbeddingLearned, SeqEmbeddingSine
from .attention import MultiheadAttention
from easydict import EasyDict as EDict
import numpy as np

class QueryDecoder(nn.Module):
    """Build per-frame decoder queries and run them through the transformer decoder.

    A ``TemplateGenerator`` turns the frame-level and video-level global tokens
    into one anchor (position query) and one content query per frame. These are
    padded to the longest clip in the batch and handed, together with the
    multi-scale visual / text memories found in ``memory_cache``, to a
    ``TransformerDecoder``.
    """

    def __init__(self, cfg):
        """Create the template generator, decoder stack and time embedding.

        Args:
            cfg: configuration object. The fields read here are
                ``MODEL.SAVGDETR.{HIDDEN, HEADS, FFN_DIM, DROPOUT, DEC_LAYERS,
                QUERY_DIM, USE_LEARN_TIME_EMBED}``, ``INPUT.MAX_VIDEO_LEN`` and
                ``SOLVER.USE_ATTN``.
        """
        super().__init__()
        d_model = cfg.MODEL.SAVGDETR.HIDDEN
        nhead = cfg.MODEL.SAVGDETR.HEADS
        dim_feedforward = cfg.MODEL.SAVGDETR.FFN_DIM
        dropout = cfg.MODEL.SAVGDETR.DROPOUT
        activation = "relu"
        num_layers = cfg.MODEL.SAVGDETR.DEC_LAYERS

        self.d_model = d_model
        self.query_pos_dim = cfg.MODEL.SAVGDETR.QUERY_DIM
        self.nhead = nhead
        self.video_max_len = cfg.INPUT.MAX_VIDEO_LEN
        self.return_weights = cfg.SOLVER.USE_ATTN
        # The decoder always keeps the output of every layer.
        return_intermediate_dec = True

        self.template_generator = TemplateGenerator(cfg)
        decoder_layer = TransformerDecoderLayer(
            cfg,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation
        )

        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            cfg,
            decoder_layer,
            num_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
            return_weights=self.return_weights,
            d_model=d_model,
            query_dim=self.query_pos_dim
        )

        # Temporal position embedding added to the queries.
        if cfg.MODEL.SAVGDETR.USE_LEARN_TIME_EMBED:
            self.time_embed = SeqEmbeddingLearned(self.video_max_len, d_model)
        else:
            self.time_embed = SeqEmbeddingSine(self.video_max_len, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter of this module with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, memory_cache, vis_pos=None, vis_poss4=None, vis_poss3=None, query_embed= None):
        """Decode one query per frame.

        Args:
            memory_cache: mapping that must provide ``"encoded_memory"`` (only its
                device is used), ``"durations"`` (list with the number of frames
                of every video), ``"fea_map_size"``, ``"text_mask"``,
                ``"frames_cls"``, ``"videos_cls"``, ``"multiscale_vis"``,
                ``"multiscale_mask"``, ``"multiscale_pos"`` and
                ``"multi_text_feats"``.
            vis_pos, vis_poss4, vis_poss3, query_embed: unused; kept so existing
                callers keep working. The positional embeddings actually used
                come from ``memory_cache["multiscale_pos"]``.

        Returns:
            Whatever ``self.decoder`` returns (a ``(outputs, weights)`` pair).
        """
        durations = memory_cache["durations"]
        fea_map_size = memory_cache["fea_map_size"]  # (H, W) of the feature map
        text_memory_mask = memory_cache["text_mask"]

        # Contextual tokens used to generate the dynamic anchors / content queries.
        frames_cls = memory_cache["frames_cls"]  # n_frames x d_model
        videos_cls = memory_cache["videos_cls"]  # video-level global token, b x d_model

        # Multi-scale memories consumed by the decoder (indexed per scale there).
        img_memory = memory_cache["multiscale_vis"]
        vis_mask = memory_cache["multiscale_mask"]
        memory_pos_embed = memory_cache["multiscale_pos"]
        text_memory = memory_cache["multi_text_feats"]

        b = len(durations)
        t = max(durations)
        device = memory_cache["encoded_memory"].device

        pos_query, content_query = self.template_generator(frames_cls, videos_cls, durations)

        # Both queries come back concatenated over all frames of all videos;
        # split them per video so they can be padded to the longest clip.
        pos_query = torch.split(pos_query.sigmoid(), durations, dim=0)
        content_per_video = torch.split(content_query, durations, dim=0)

        # Padded content queries, [t, b, d_model]. For a single video this equals
        # the previous ``content_query.expand(t, d).unsqueeze(1)``, which could
        # not be evaluated for more than one video.
        tgt = content_query.new_zeros(t, b, content_query.size(-1))

        # Padded anchors and the padding mask of the queries (True = padding).
        query_pos_embed = torch.zeros(b, t, self.query_pos_dim).to(device)
        query_mask = torch.ones(b, t, dtype=torch.bool, device=device)
        query_mask[:, 0] = False  # avoid fully-masked rows

        for i_dur, dur in enumerate(durations):
            query_mask[i_dur, :dur] = False
            query_pos_embed[i_dur, :dur, :] = pos_query[i_dur]
            tgt[:dur, i_dur, :] = content_per_video[i_dur]

        # NOTE(review): the decoder layer's text cross-attention transposes the
        # queries to (b, t, d) and so still appears to assume b == 1 - confirm
        # before batching several videos.
        query_pos_embed = query_pos_embed.permute(1, 0, 2)  # [t, b, query_pos_dim]
        query_time_embed = self.time_embed(t).repeat(1, b, 1)  # [t, b, d_model]

        outputs = self.decoder(
            tgt,  # t x b x c
            img_memory,  # per-scale visual memory
            tgt_key_padding_mask=query_mask,  # b x t
            memory_key_padding_mask=vis_mask,  # per-scale padding mask
            pos=memory_pos_embed,  # per-scale positional embedding
            query_anchor=query_pos_embed,  # t x b x query_pos_dim
            query_time=query_time_embed,
            durations=durations,
            fea_map_size=fea_map_size,
            text_memory=text_memory,
            text_memory_mask=text_memory_mask,
        )

        return outputs
        


class TransformerDecoder(nn.Module):
    """Stack of decoder layers with anchor-conditioned positional queries.

    Every layer receives a positional query computed from the current anchor
    (sine embedding -> ``ref_point_head``). When ``bbox_embed`` is set (it is
    ``None`` after construction), the anchor is refined after each layer.
    Different layers read different entries of the multi-scale memory: layer 0
    uses index ``-1``, layers 1 and 2 use ``-2`` and all later layers use ``-3``.
    """

    def __init__(self, cfg, decoder_layer, num_layers, norm=None, return_intermediate=False,
                    return_weights=False, d_model=256, query_dim=4):
        """Clone ``decoder_layer`` ``num_layers`` times and build the query heads.

        Args:
            cfg: configuration; ``MODEL.SAVGDETR.SPAT_GT_THETA`` and
                ``MODEL.SAVGDETR.SPAT_THETA`` are read.
            decoder_layer: layer prototype that is deep-copied.
            num_layers: number of layers in the stack.
            norm: optional normalisation applied to the layer outputs.
            return_intermediate: keep the (normalised) output of every layer.
            return_weights: also return the cross-attention weights per layer.
            d_model: hidden size.
            query_dim: number of anchor coordinates.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.return_weights = return_weights
        self.query_dim = query_dim
        self.d_model = d_model

        self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
        # Left unset here; the iterative anchor update in ``forward`` is skipped
        # while this is None.
        self.bbox_embed = None

        # Only the first layer is called with ``is_first=True`` and therefore
        # only it uses ``ca_qpos_proj``; drop the projection from the others.
        for layer_id in range(num_layers - 1):
            self.layers[layer_id + 1].ca_qpos_proj = None

        self.theta_s_gt = cfg.MODEL.SAVGDETR.SPAT_GT_THETA
        self.theta_s = cfg.MODEL.SAVGDETR.SPAT_THETA

    def gt_info(self, targets):
        """Collect the ground-truth boxes and the padded frame indices they belong to.

        Args:
            targets: list of per-video target dicts. ``"boxs"`` is read directly
                when the first target has an ``"eval"`` key, otherwise through
                its ``.bbox`` attribute. ``targets[0]["durations"]`` and each
                target's ``"actioness"`` are also read.

        Returns:
            ``(target_boxes, gt_bbox_slice)``: the concatenated boxes and a
            ``LongTensor`` with, for every video, the indices (in a
            ``b * max_duration`` flattened layout) from the first to the last
            frame whose ``actioness`` is non-zero.
        """
        if "eval" in targets[0].keys():
            target_boxes = torch.cat([target["boxs"] for target in targets], dim=0)
        else:
            target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0)

        gt_bbox_slice = []
        durations = targets[0]["durations"]

        max_duration = max(durations)
        for i_dur, (duration, target) in enumerate(zip(durations, targets)):
            inter = torch.where(target['actioness'])[0].cpu().numpy().tolist()
            gt_bbox_slice.extend(list(range(i_dur * max_duration + inter[0], i_dur * max_duration + inter[-1] + 1)))
        gt_bbox_slice = torch.LongTensor(gt_bbox_slice)
        return target_boxes, gt_bbox_slice

    def get_context_index_by_gt(self, pred_boxes, target_boxes, gt_bbox_slice, conf_threshold):
        """Select frames whose predicted box overlaps its ground-truth box enough.

        The generalized IoU between every selected prediction and its matching
        target (the diagonal of the pairwise matrix) is written into a per-frame
        score vector, and the indices returned by ``greater_than_indices`` for
        ``conf_threshold`` are given back flattened.
        """
        conf_list = torch.zeros(pred_boxes.shape[0]).to(pred_boxes.device)
        if len(gt_bbox_slice) > 1:
            pred_boxes = pred_boxes[gt_bbox_slice].squeeze()
        else:
            pred_boxes = pred_boxes[gt_bbox_slice][0]

        iou = generalized_box_iou(box_cxcywh_to_xyxy(pred_boxes), box_cxcywh_to_xyxy(target_boxes))

        conf_list[gt_bbox_slice[0]:gt_bbox_slice[-1] + 1] = torch.diag(iou)

        context_index = greater_than_indices(conf_list, conf_threshold).reshape(-1)

        return context_index

    def get_context_index_by_our(self, conf_list, conf_threshold=0.7):
        """Return the flattened indices ``greater_than_indices`` yields for ``conf_list``."""
        context_index = greater_than_indices(conf_list, conf_threshold).reshape(-1)
        return context_index

    def generate_roi_feature(self, all_feature, v_size, bboxs, type):
        """Average-pool the feature map of every frame inside its box.

        Args:
            all_feature: token features; the first ``v_size[0] * v_size[1]``
                tokens of each frame are reshaped into an ``(H, W, C)`` map.
            v_size: ``(H, W)`` of the feature map.
            bboxs: normalised ``cxcywh`` boxes, one per frame.
            type: unused; kept for interface compatibility.

        Returns:
            Tensor with one pooled ``C``-dimensional feature per box.
        """
        feat_dim = all_feature.size(-1)
        feature_map = all_feature.permute(1, 0, 2)[:, :v_size[0] * v_size[1]]  # [b, H*W, C]

        feature_map = feature_map.reshape(-1, v_size[0], v_size[1], feat_dim)  # [b, H, W, C]
        # ``reshape(-1, 4)`` instead of ``squeeze()``: a single-frame clip must
        # stay two-dimensional for the column indexing below.
        bboxs = box_cxcywh_to_xyxy(bboxs).clamp(min=0).reshape(-1, 4) * torch.Tensor(
            [v_size[1], v_size[0], v_size[1], v_size[0]]).to(bboxs.device)
        # NOTE(review): x2 is rounded up while y2 is rounded to nearest - confirm
        # the asymmetry is intended.
        bboxs = torch.stack([(bboxs[:, 0]).round(), (bboxs[:, 1]).round(), (bboxs[:, 2]).ceil(), (bboxs[:, 3]).round()],
                            dim=-1).long()
        roi_feature = []
        for i in range(len(bboxs)):
            f = feature_map[i].clone()
            x1, y1, x2, y2 = bboxs[i]
            # Clamp to the map and force every box to cover at least one cell.
            x2 = min(max(x2, 1), f.size(1))
            x1 = min(max(x1, 0), x2 - 1)
            y2 = min(max(y2, 1), f.size(0))
            y1 = min(max(y1, 0), y2 - 1)
            try:
                r = f[y1:y2, x1:x2].clone().reshape(-1, feat_dim)
                pooling_r = torch.mean(r, dim=0)
            except Exception:
                # Best-effort fallback kept from the original code: a zero
                # feature, now sized and typed like the real features so that
                # ``torch.stack`` below cannot fail on a mismatch.
                pooling_r = all_feature.new_zeros(feat_dim)
            roi_feature.append(pooling_r)
        return torch.stack(roi_feature)

    def generate_context(self, roi_2d, roi_2d_multiscal, index=None):
        """Project the selected ROI features of two scales and concatenate them.

        Raises:
            AttributeError: if ``gf_mlp`` / ``gf_mlp_multiscal`` have not been
                attached to the module (``__init__`` does not create them).
        """
        if not (hasattr(self, "gf_mlp") and hasattr(self, "gf_mlp_multiscal")):
            raise AttributeError(
                "generate_context requires the 'gf_mlp' and 'gf_mlp_multiscal' "
                "sub-modules, which are not created in __init__"
            )
        context_2d = self.gf_mlp(roi_2d[index])
        context_multiscal = self.gf_mlp_multiscal(roi_2d_multiscal[index])
        context = torch.cat((context_2d, context_multiscal), dim=0)
        return context

    def update_memory_key_padding_mask(self, memory_key_padding_mask, index, v_size, bboxs):
        """Mask every visual token outside the box for the frames listed in ``index``.

        Args:
            memory_key_padding_mask: ``[bT, H*W + Nt]`` padding mask.
            index: frame indices whose visual mask is replaced.
            v_size: ``(H, W)`` of the feature map.
            bboxs: normalised ``cxcywh`` boxes, one per frame.

        Returns:
            A mask of the same shape; the trailing (non-visual) part is copied
            unchanged.
        """
        mask_map = memory_key_padding_mask[:, :v_size[0] * v_size[1]]  # [bT, H*W+Nt] -> [bT, H*W]
        mask_map = mask_map.reshape(-1, v_size[0], v_size[1])  # [bT, H, W]
        # ``reshape(-1, 4)`` keeps a single-frame clip two-dimensional.
        bboxs = box_cxcywh_to_xyxy(bboxs).clamp(min=0).reshape(-1, 4) * torch.Tensor([v_size[1], v_size[0], v_size[1], v_size[0]]).to(bboxs.device)
        bboxs = torch.stack([(bboxs[:, 0]).round(), (bboxs[:, 1]).round(), (bboxs[:, 2]).ceil(), (bboxs[:, 3]).round()],dim=-1).long()

        new_mask = []
        for i in range(len(bboxs)):
            mask = mask_map[i]
            if i in index:
                # Start from an all-True tensor shaped like ``mask`` and unmask
                # only the cells inside the box.
                mask_expanded = torch.ones_like(mask, dtype=torch.bool)
                x1, y1, x2, y2 = bboxs[i]
                x2 = min(max(x2, 1), mask.size(1))
                x1 = min(max(x1, 0), x2 - 1)
                y2 = min(max(y2, 1), mask.size(0))
                y1 = min(max(y1, 0), y2 - 1)
                mask_expanded[y1:y2, x1:x2] = False
                new_mask.append(mask_expanded)
            else:
                new_mask.append(mask)
        vis_mask = torch.stack(new_mask).flatten(1)  # [bT, H, W] -> [bT, H*W]
        update_memory_mask = torch.cat([vis_mask, memory_key_padding_mask[:, v_size[0] * v_size[1]:]], dim=1)
        assert update_memory_mask.shape == memory_key_padding_mask.shape
        return update_memory_mask

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,   # the pos for feature map
        query_anchor: Optional[Tensor] = None, # the anchor pos embedding
        query_time = None,   # the query time position embedding
        durations=None,
        fea_map_size = None,
        targets=None,
        iteration_rate=None,
        text_memory=None,
        text_memory_mask=None,
    ):
        """Run all decoder layers.

        Args:
            tgt: content queries, ``t x b x c``.
            memory, memory_key_padding_mask, pos, text_memory: sequences indexed
                with ``-1`` / ``-2`` / ``-3`` depending on the layer.
            tgt_mask, memory_mask, tgt_key_padding_mask, text_memory_mask:
                forwarded unchanged to every layer.
            query_anchor: anchors, ``t x b x query_dim``.
            query_time: temporal embedding of the queries.
            durations: number of frames of every video.
            fea_map_size, targets, iteration_rate: currently unused.

        Returns:
            ``(outputs, weights)``. ``outputs`` is a two-element list
            ``[hidden_states, anchors]``; with ``return_intermediate`` the hidden
            states of all layers are stacked, otherwise only the final one is
            returned with a leading dimension of size 1. ``weights`` is the list
            of per-layer cross-attention weights, or ``None`` when
            ``return_weights`` is off.
        """
        output = tgt
        intermediate = []
        intermediate_weights = []
        ref_anchors = [query_anchor]   # the query pos is like t x b x 4
        # The instance-context modulation (ICG / ICR) is disabled, so no context
        # is passed to the layers.
        # TODO: use the context indices and query_anchor to ignore the patches
        # outside the predicted region of each frame via
        # update_memory_key_padding_mask.
        context = None

        for layer_id, layer in enumerate(self.layers):
            obj_center = query_anchor[..., :self.query_dim]     # [num_queries, batch_size, 4]
            # Sine embedding of the anchor, turned into the positional query.
            query_sine_embed = gen_sineembed_for_position(obj_center)
            query_pos = self.ref_point_head(query_sine_embed)

            # For the first decoder layer, we do not apply transformation over p_s
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(output)

            query_sine_embed = query_sine_embed[..., :self.d_model] * pos_transformation

            # Pick which entry of the multi-scale inputs this layer attends to.
            if layer_id == 0:
                scale = -1
            elif layer_id in (1, 2):
                scale = -2
            else:
                scale = -3

            output, img_weights = layer(
                output,
                memory[scale],
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask[scale],
                pos=pos[scale],
                query_pos=query_pos,
                query_time_embed=query_time,
                query_sine_embed=query_sine_embed,
                context=context,
                durations=durations,
                is_first=(layer_id == 0),
                text_memory=text_memory[scale],
                text_memory_mask=text_memory_mask,
            )

            # Iterative anchor update: predicted offset + current anchor.
            if self.bbox_embed is not None:
                tmp = self.bbox_embed(output)    # t, b, 4
                tmp[..., :self.query_dim] += inverse_sigmoid(query_anchor)
                new_query_anchor = tmp[..., :self.query_dim].sigmoid()
                if layer_id != self.num_layers - 1:
                    ref_anchors.append(new_query_anchor)
                # Detach so gradients do not flow through the anchor chain.
                query_anchor = new_query_anchor.detach()

            if self.return_intermediate:
                intermediate.append(output if self.norm is None else self.norm(output))
            if self.return_weights:
                intermediate_weights.append(img_weights)

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            if self.bbox_embed is not None:
                outputs = [
                    torch.stack(intermediate).transpose(1, 2),
                    torch.stack(ref_anchors).transpose(1, 2),
                ]
            else:
                outputs = [
                    torch.stack(intermediate).transpose(1, 2),
                    query_anchor.unsqueeze(0).transpose(1, 2)
                ]
        else:
            # Previously ``outputs`` was left undefined on this path; return the
            # final state in the same layout with a single "layer" entry.
            outputs = [
                output.unsqueeze(0).transpose(1, 2),
                query_anchor.unsqueeze(0).transpose(1, 2)
            ]

        if self.return_weights:
            return outputs, intermediate_weights
        else:
            return outputs, None


class TransformerDecoderLayer(nn.Module):
    """One decoder layer: temporal self-attention, text cross-attention,
    per-frame image cross-attention and a feed-forward network.

    Every sub-block is residual and followed by its own LayerNorm
    (``norm1`` .. ``norm4``).
    """

    def __init__(
        self,
        cfg,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
    ):
        """Create the projections, attention modules, FFN, norms and dropouts.

        ``cfg.MODEL.SAVGDETR.FROM_SCRATCH`` selects the image cross-attention:
        the project ``MultiheadAttention`` with a ``2 * d_model`` wide
        query/key when true, otherwise ``nn.MultiheadAttention`` with
        ``d_model``.
        """
        super().__init__()
        # Decoder Self-Attention
        self.sa_qcontent_proj = nn.Linear(d_model, d_model)
        self.sa_qpos_proj = nn.Linear(d_model, d_model)
        self.sa_qtime_proj = nn.Linear(d_model, d_model)
        self.sa_kcontent_proj = nn.Linear(d_model, d_model)
        self.sa_kpos_proj = nn.Linear(d_model, d_model)
        self.sa_ktime_proj = nn.Linear(d_model, d_model)
        self.sa_v_proj = nn.Linear(d_model, d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
        
        # Decoder Cross-Attention
        self.ca_qcontent_proj = nn.Linear(d_model, d_model)
        # NOTE: TransformerDecoder.__init__ sets this to None on every layer but
        # the first; forward only uses it when ``is_first`` is true.
        self.ca_qpos_proj = nn.Linear(d_model, d_model)
        self.ca_qtime_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
        self.ca_kcontent_proj = nn.Linear(d_model, d_model)
        self.ca_kpos_proj = nn.Linear(d_model, d_model)
        self.ca_v_proj = nn.Linear(d_model, d_model)



        # Text cross-attention together with its norm / dropout.
        self.cross_attn_text = nn.MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)


        self.from_scratch_cross_attn = cfg.MODEL.SAVGDETR.FROM_SCRATCH
        self.cross_attn_image = None
        self.cross_attn = None
        self.tgt_proj = None
        
        # Exactly one of ``cross_attn`` / ``cross_attn_image`` is created.
        if self.from_scratch_cross_attn:
            self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model)
        else:
            self.cross_attn_image = nn.MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model) 
        
        self.nhead = nhead
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # norm2 / dropout2 are created above, next to the text cross-attention.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        query_time_embed=None,
        query_sine_embed = None,
            context=None,
        durations=None,
        is_first = False,
        text_memory=None,
        text_memory_mask=None,

    ):
        """Apply the layer to the queries.

        Args:
            tgt: queries, unpacked below as ``(t, b, c)`` with ``b`` videos.
            memory: visual memory, unpacked below as ``(n_tokens, bs, f)`` with
                ``bs`` the total number of frames in the batch.
            tgt_mask: attention mask of the self-attention.
            memory_mask: attention mask passed to both cross-attentions.
            tgt_key_padding_mask: key padding mask of the self-attention.
            memory_key_padding_mask: key padding mask of the image cross-attention.
            pos: positional embedding of ``memory``.
            query_pos: positional embedding of the queries.
            query_time_embed: temporal embedding of the queries.
            query_sine_embed: sine embedding of the anchors.
            context: accepted but not used in this method.
            durations: number of valid frames of every video; must sum to ``bs``.
            is_first: whether this is the first layer of the stack.
            text_memory: keys/values of the text cross-attention.
            text_memory_mask: key padding mask of the text cross-attention.

        Returns:
            ``(tgt, img_weights)``: the updated queries and the attention
            weights returned by the image cross-attention.
        """
        # Apply projections here
        # shape: num_queries x batch_size x 256
        # ========== Begin of Self-Attention =============
        q_content = self.sa_qcontent_proj(tgt)
        q_time = self.sa_qtime_proj(query_time_embed)
        q_pos = self.sa_qpos_proj(query_pos)
        k_content = self.sa_kcontent_proj(tgt)
        k_time = self.sa_ktime_proj(query_time_embed)
        k_pos = self.sa_kpos_proj(query_pos)
        v = self.sa_v_proj(tgt)
        
        # Queries and keys sum content, time and position; values are content only.
        q = q_content + q_time + q_pos
        k = k_content + k_time + k_pos

        # Temporal self-attention: the frames of a video attend to each other.
        tgt2, weights = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
                key_padding_mask=tgt_key_padding_mask)

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # ========== End of Self-Attention =============


        # ========== Begin of Text Cross-Attention =============
        # The queries are transposed to (b, t, c), so the attention module sees
        # the video axis as sequence and the frame axis as batch.
        # NOTE(review): this lines up with text_memory ([Nt, T, D] according to
        # the original comments) only when b == 1 - confirm before batching videos.
        tgt2, cross_weights = self.cross_attn_text(
            query=self.with_pos_embed(tgt.transpose(0, 1), query_pos.transpose(0, 1)),  # [1, T, D]
            key=text_memory,  # [Nt, T, D]
            value=text_memory,  # [Nt, T, D]
            attn_mask=memory_mask,
            key_padding_mask=text_memory_mask,  # presumably [T, Nt] - TODO confirm
        )
        tgt = tgt.transpose(0, 1) + self.dropout2(tgt2)
        tgt = tgt.transpose(0, 1)  # bxtxf -> txbxf
        tgt = self.norm2(tgt)
        # ========== End of Text Cross-Attention =============


        # ========== Begin of Cross-Attention =============
        # Time Aligned Cross attention
        t, b, c = tgt.shape    # b is the video number
        n_tokens, bs, f = memory.shape   # bs is the total frames in a batch
        assert f == c   # all the token dim should be same
        
        q_content = self.ca_qcontent_proj(tgt)
        k_content = self.ca_kcontent_proj(memory)
        v = self.ca_v_proj(memory)
        k_pos = self.ca_kpos_proj(pos)
        # Only the first layer adds the projected positions to query and key here.
        if is_first:
            q_pos = self.ca_qpos_proj(query_pos)
            q = q_content + q_pos
            k = k_content + k_pos
        else:
            q = q_content
            k = k_content

            
        # Split into heads so the sine embedding is combined per head.
        q = q.view(t, b, self.nhead, c // self.nhead)
        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
        query_sine_embed = query_sine_embed.view(t, b, self.nhead, c // self.nhead)
        
        if self.from_scratch_cross_attn:
            # Concatenate content and sine embedding per head -> width 2 * c.
            q = torch.cat([q, query_sine_embed], dim=3).view(t, b, c * 2)
        else:
            # Add the sine embedding and the projected time embedding instead.
            q = (q + query_sine_embed).view(t, b, c)
            q = q + self.ca_qtime_proj(query_time_embed)
        
        k = k.view(n_tokens, bs, self.nhead, f//self.nhead)
        k_pos = k_pos.view(n_tokens, bs, self.nhead, f//self.nhead)
        
        # Same concat-vs-add choice for the keys. In the additive branch k_pos is
        # added here in addition to the ``is_first`` addition above.
        if self.from_scratch_cross_attn:
            k = torch.cat([k, k_pos], dim=3).view(n_tokens, bs, f * 2)
        else:
            k = (k + k_pos).view(n_tokens, bs, f)
            
        # Gather the valid (non-padded) queries of all videos into a single
        # [1, bs, *] tensor so every frame attends only to its own memory column.
        clip_start = 0
        device = tgt.device
        if self.from_scratch_cross_attn:
            q_cross = torch.zeros(1,bs,2 * c).to(device)
        else:
            q_cross = torch.zeros(1,bs,c).to(device)
        
        for i_b in range(b):
            q_clip = q[:,i_b,:]   # t x f
            clip_length = durations[i_b]
            q_cross[0,clip_start : clip_start + clip_length] = q_clip[:clip_length]
            clip_start += clip_length
        
        # The durations must account for every frame in the memory.
        assert clip_start == bs
        
        if self.from_scratch_cross_attn:
            tgt2, img_weights = self.cross_attn(
                query=q_cross,
                key=k,
                value=v,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
            )
        else:
            tgt2, img_weights = self.cross_attn_image(
                query=q_cross,
                key=k,
                value=v,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
            )
     
        # Scatter the per-frame results back into the padded layout; padded
        # positions stay zero.
        clip_start = 0
        tgt2_pad = torch.zeros(1,t*b,c).to(device)
        
        for i_b in range(b):
            clip_length = durations[i_b]
            tgt2_pad[0,i_b * t:i_b * t + clip_length] = tgt2[0,clip_start : clip_start + clip_length]
            clip_start += clip_length

        tgt2 = tgt2_pad
        tgt2 = tgt2.view(b, t, f).transpose(0, 1)  # 1x(b*t)xf -> bxtxf -> txbxf

        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm4(tgt)

        return tgt, img_weights


class TemplateGenerator(nn.Module):
    """Produce one anchor and one content query for every frame.

    The video-level token yields a scale and a shift vector (both squashed with
    ``tanh``) that modulate the frame-level tokens before they are projected to
    anchors. The content query of a frame is the projected video-level token of
    the video the frame belongs to.
    """

    def __init__(self, cfg):
        """Read ``MODEL.SAVGDETR.HIDDEN`` / ``QUERY_DIM`` and build the projections."""
        super().__init__()
        model_cfg = cfg.MODEL.SAVGDETR
        self.d_model = model_cfg.HIDDEN
        self.pos_query_dim = model_cfg.QUERY_DIM
        hidden = self.d_model
        self.content_proj = nn.Linear(hidden, hidden)
        self.gamma_proj = nn.Linear(hidden, hidden)
        self.beta_proj = nn.Linear(hidden, hidden)
        self.anchor_proj = nn.Linear(hidden, self.pos_query_dim)

    def forward(
            self,
            frames_cls=None,
            videos_cls=None,  # [b, d_model]
            durations=None,
            text_cls=None  # [b, d_model]
    ):
        """Generate the queries of all frames in the batch.

        Args:
            frames_cls: frame-level tokens of all videos, concatenated along dim 0.
            videos_cls: one video-level token per video.
            durations: number of frames of every video; used to split ``frames_cls``.
            text_cls: unused.

        Returns:
            ``(pos_query, temp_query)``, both concatenated over all frames:
            the raw anchors and the per-frame content queries.
        """
        clips = torch.split(frames_cls, durations, dim=0)
        projected_videos = self.content_proj(videos_cls)

        anchors = []
        contents = []
        for vid_idx, clip_tokens in enumerate(clips):
            video_token = videos_cls[vid_idx]
            scale = torch.tanh(self.gamma_proj(video_token))
            shift = torch.tanh(self.beta_proj(video_token))
            modulated = scale * clip_tokens + shift
            anchors.append(self.anchor_proj(modulated))

            # Repeat the video's content query once per frame of the clip.
            n_frames = clip_tokens.shape[0]
            contents.append(projected_videos[vid_idx].unsqueeze(0).repeat(n_frames, 1))

        return torch.cat(anchors, dim=0), torch.cat(contents, dim=0)


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


def _get_activation_fn(activation):
    """Return the activation function named by ``activation``.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message now lists every supported name ("glu" was missing).
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
