import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from .net_utils import MLP, inverse_sigmoid,box_cxcywh_to_xyxy
from .vision_model import build_vis_encoder
from .language_model import build_text_encoder
from .grounding_model import build_encoder, build_decoder
from utils.misc import NestedTensor
from utils.misc import get_activation
from .uav_modules.block import *


class ConvNormLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation.

    Args:
        ch_in: input channels.
        ch_out: output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; when None, ``(kernel_size - 1) // 2`` is used,
            which keeps the spatial size for odd kernels at stride 1.
        bias: whether the convolution has a bias term.
        act: activation name resolved via ``get_activation``; None means identity.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size, stride, padding=padding, bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = get_activation(act) if act is not None else nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)


class RepVggBlock(nn.Module):
    """RepVGG block: parallel 3x3 and 1x1 conv-BN branches, summed and then activated.

    ``convert_to_deploy`` folds both branches (with their BatchNorms, using the
    running statistics) into a single 3x3 convolution stored as ``self.conv``;
    once that attribute exists ``forward`` uses it instead of the two branches.
    """

    def __init__(self, ch_in, ch_out, act='relu'):
        super().__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
        self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
        self.act = get_activation(act) if act is not None else nn.Identity()

    def forward(self, x):
        if not hasattr(self, 'conv'):
            # Training-time form: evaluate both branches separately.
            return self.act(self.conv1(x) + self.conv2(x))
        return self.act(self.conv(x))

    def convert_to_deploy(self):
        """Fold conv1/conv2 into one 3x3 conv stored as ``self.conv``.

        conv1/conv2 are left in place (they are not deleted).
        """
        if not hasattr(self, 'conv'):
            self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
        fused_kernel, fused_bias = self.get_equivalent_kernel_bias()
        self.conv.weight.data = fused_kernel
        self.conv.bias.data = fused_bias

    def get_equivalent_kernel_bias(self):
        """Return the (kernel, bias) of the single 3x3 conv equivalent to conv1 + conv2."""
        k3, b3 = self._fuse_bn_tensor(self.conv1)
        k1, b1 = self._fuse_bn_tensor(self.conv2)
        return k3 + self._pad_1x1_to_3x3_tensor(k1), b3 + b1

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel to 3x3 (value kept at the centre tap); None maps to 0."""
        return 0 if kernel1x1 is None else F.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch: ConvNormLayer):
        """Fold a branch's BatchNorm (running stats) into its conv weight.

        Returns (kernel, bias). The branch conv's own bias, if any, is not read.
        A None branch yields (0, 0).
        """
        if branch is None:
            return 0, 0
        bn = branch.norm
        std = (bn.running_var + bn.eps).sqrt()
        per_channel = (bn.weight / std).reshape(-1, 1, 1, 1)
        return branch.conv.weight * per_channel, bn.bias - bn.running_mean * bn.weight / std


class CSPRepLayer(nn.Module):
    """CSP-style layer built from RepVgg blocks.

    Two 1x1 projections of the input are computed. One goes through
    ``num_blocks`` RepVggBlocks, and the two results are summed. The sum is
    projected to ``out_channels`` by a final 1x1 conv, or passed through
    unchanged when the hidden width already equals ``out_channels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_blocks=3,
                 expansion=1.0,
                 bias=None,
                 act="silu"):
        super(CSPRepLayer, self).__init__()
        hidden = int(out_channels * expansion)
        self.conv1 = ConvNormLayer(in_channels, hidden, 1, 1, bias=bias, act=act)
        self.conv2 = ConvNormLayer(in_channels, hidden, 1, 1, bias=bias, act=act)
        self.bottlenecks = nn.Sequential(
            *(RepVggBlock(hidden, hidden, act=act) for _ in range(num_blocks))
        )
        self.conv3 = (
            ConvNormLayer(hidden, out_channels, 1, 1, bias=bias, act=act)
            if hidden != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        # Left operand is evaluated first: conv1 -> bottlenecks, then conv2.
        return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))


class MultiscalMultimodalFusion(nn.Module):
    """Language-conditioned fusion of three visual feature levels into one map.

    Each level is instance-normalized and then modulated by language-derived
    ``gamma``/``beta`` vectors (``relu(gamma * feat + beta)``, both passed
    through tanh). Each level then goes through its own 1x1 conv block. The
    three levels are merged at the resolution of ``visu_feats[1]``:
    ``visu_feats[0]`` is nearest-upsampled, ``visu_feats[-1]`` is downsampled by
    a stride-2 conv, and the three maps are averaged and refined.
    """

    def __init__(self, cfg):
        super().__init__()
        self.d_model = cfg.MODEL.SAVGDETR.HIDDEN
        dim = self.d_model
        num_levels = 3  # one module per fused level

        self.gamma_proj = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_levels)])
        self.beta_proj = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_levels)])
        self.norms = nn.ModuleList([nn.InstanceNorm2d(dim) for _ in range(num_levels)])
        self.joint_fusion = nn.ModuleList([self._make_conv(dim, dim, 1) for _ in range(num_levels)])
        self.downsample_conv = ConvNormLayer(dim, dim, 3, 2, act='silu')
        self.refine = self._make_conv(dim, dim, 1)

    def _make_conv(self, input_dim, output_dim, k, stride=1):
        """k x k Conv2d (padding (k - 1) // 2) -> BatchNorm2d -> ReLU."""
        pad = (k - 1) // 2
        conv = nn.Conv2d(input_dim, output_dim, (k, k), padding=(pad, pad), stride=(stride, stride))
        return nn.Sequential(conv, nn.BatchNorm2d(output_dim), nn.ReLU(inplace=True))

    def forward(
            self,
            visu_feats=None,  # [[n_frames, d_model, h, w],[n_frames, d_model, h, w],[n_frames, d_model, h, w],]
            lang_feat=None,  # [b, d_model]
    ):
        n_frames = visu_feats[0].shape[0]
        # Each row of lang_feat repeated n_frames times -> [b * n_frames, d_model].
        # NOTE(review): the view(n_frames, -1, 1, 1) below only matches the d_model
        # channels of the features when b == 1; confirm with callers.
        lang_rep = lang_feat.repeat_interleave(n_frames, dim=0)

        gammas = [torch.tanh(proj(lang_rep)) for proj in self.gamma_proj]
        betas = [torch.tanh(proj(lang_rep)) for proj in self.beta_proj]

        modulated = []
        for level, feat in enumerate(visu_feats):
            feat = self.norms[level](feat)
            scale = gammas[level].view(n_frames, -1, 1, 1).expand_as(feat)
            shift = betas[level].view(n_frames, -1, 1, 1).expand_as(feat)
            modulated.append(F.relu(scale * feat + shift))

        # Per-level 1x1 fusion, then bring the outer levels to the middle level's size.
        joint = [fuse(feat) for feat, fuse in zip(modulated, self.joint_fusion)]
        target_hw = (joint[1].shape[-2], joint[1].shape[-1])
        upsampled = F.interpolate(joint[0], size=target_hw, mode='nearest')
        downsampled = self.downsample_conv(joint[-1])

        return self.refine((joint[1] + upsampled + downsampled) / 3.)

class PhraseAttention(nn.Module):
    """Attention pooling over a sequence.

    Each position is scored by a single linear layer. The scores are
    softmax-normalized over the sequence, and the attention-weighted sum of the
    embeddings is returned.
    """

    def __init__(self, input_dim):
        super(PhraseAttention, self).__init__()
        # Maps every sequence element to one scalar score.
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, embedded):
        """
        Inputs:
        - embedded: float tensor (batch, seq_len, input_dim)
        Outputs:
        - weighted_emb: float tensor (batch, input_dim), the attention-weighted
          sum of ``embedded`` over seq_len
        """
        cxt_scores = self.fc(embedded).squeeze(2)  # (batch, seq_len)
        # softmax rows already sum to 1, so no extra renormalization is needed.
        attn = F.softmax(cxt_scores, dim=1)  # (batch, seq_len)

        weighted_emb = torch.bmm(attn.unsqueeze(1), embedded)  # (batch, 1, input_dim)
        return weighted_emb.squeeze(1)  # (batch, input_dim)


class SAVGDETR(nn.Module):
    """
    SAVGDETR, a DETR-style model for spatial aerial video grounding. It consists of:
    - a visual encoder whose last three feature levels (s3, s4, s5) are used;
    - a language encoder;
    - a multimodal grounding encoder that operates on the s5 level and the text;
    - a cross-scale neck (top-down FPN + bottom-up PAN) that fuses the encoded s5
      memory with the projected s3/s4 backbone features;
    - a query-based grounding decoder whose states are turned into boxes by ``bbox_embed``.
    """

    def __init__(self, cfg):
        super(SAVGDETR, self).__init__()
        self.cfg = cfg.clone()
        self.max_video_len = cfg.INPUT.MAX_VIDEO_LEN
        self.use_attn = cfg.SOLVER.USE_ATTN

        self.use_aux_loss = cfg.SOLVER.USE_AUX_LOSS  # use the output of each transformer layer
        self.use_actioness = cfg.MODEL.SAVGDETR.USE_ACTION
        self.query_dim = cfg.MODEL.SAVGDETR.QUERY_DIM

        self.vis_encoder = build_vis_encoder(cfg)
        vis_fea_dim = self.vis_encoder.num_channels
        self.text_encoder = build_text_encoder(cfg)

        self.ground_encoder = build_encoder(cfg)
        self.ground_decoder = build_decoder(cfg)

        hidden_dim = cfg.MODEL.SAVGDETR.HIDDEN
        self.hidden_dim = hidden_dim

        # 1x1 projections of the backbone levels to hidden_dim.
        self.input_proj = nn.Conv2d(vis_fea_dim, hidden_dim, kernel_size=1)
        # NOTE(review): hard-coded s3/s4/s5 backbone widths; confirm they match the backbone.
        self.in_channels = [512, 1024, 2048]
        self.inputs4_proj = nn.Conv2d(self.in_channels[1], hidden_dim, kernel_size=1)
        self.inputs3_proj = nn.Conv2d(self.in_channels[0], hidden_dim, kernel_size=1)

        act = 'silu'
        depth_mult = 1.0
        expansion = 1.0
        num_csp_blocks = round(3 * depth_mult)

        # Text tokens attend to each fused visual level. Sized from hidden_dim:
        # Vision_aug_Text's defaults are hard-coded to 256 channels.
        self.encoder_Vision_aug_Text = Vision_aug_Text(text_channels=hidden_dim, embed_channels=hidden_dim)

        # Top-down path: s5 -> s4.
        self.lateral_convs = nn.ModuleList([
            ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act),
            ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act),
        ])
        self.fpn_block = CSPRepLayer(hidden_dim * 2, hidden_dim, num_csp_blocks, act=act, expansion=expansion)
        self.MSFF_FE = MSFFFE(dim=hidden_dim)

        # Bottom-up path: s3 -> s4 -> s5.
        self.downsample_convs = nn.ModuleList([
            FrequencyFocusedDownSampling(hidden_dim, hidden_dim),
            FrequencyFocusedDownSampling(hidden_dim, hidden_dim),
        ])
        self.pan_blocks = nn.ModuleList([
            CSPRepLayer(hidden_dim * 2, hidden_dim, num_csp_blocks, act=act, expansion=expansion),
            CSPRepLayer(hidden_dim * 2, hidden_dim, num_csp_blocks, act=act, expansion=expansion),
        ])
        self.SAC = SemanticAlignmenCalibration([hidden_dim, hidden_dim])

        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # Share the box head with the decoder for iterative anchor updates.
        self.ground_decoder.decoder.bbox_embed = self.bbox_embed

    def _fuse_cross_scale(self, viss3_features, viss4_features, viss5_memory):
        """Fuse the three levels with a top-down FPN step and a bottom-up PAN pass.

        All inputs are [BT, hidden_dim, h, w] maps. Returns [out_s3, out_s4, out_s5].
        """
        # Top-down: upsample the lateral-projected s5 memory onto s4.
        feat_high = self.lateral_convs[0](viss5_memory)
        upsample_feat = F.interpolate(
            feat_high, size=(viss4_features.shape[-2], viss4_features.shape[-1]), mode='nearest')
        inner_s4 = self.fpn_block(torch.concat([upsample_feat, viss4_features], dim=1))
        inner_s4_feat = self.lateral_convs[1](inner_s4)

        # Bottom-up: s3 -> s4 -> s5.
        inner_s3_out = self.MSFF_FE(viss3_features)
        downsample_s3_feat = self.downsample_convs[0](inner_s3_out)
        out_s4 = self.pan_blocks[0](torch.concat([downsample_s3_feat, inner_s4_feat], dim=1))
        downsample_s4_feat = self.downsample_convs[1](out_s4)
        out_s5 = self.pan_blocks[1](torch.concat([downsample_s4_feat, feat_high], dim=1))

        # The final s3 map is calibrated against the final s5 map.
        out_s3 = self.SAC([inner_s3_out, out_s5])
        return [out_s3, out_s4, out_s5]

    def forward(self, videos, texts, targets, iteration_rate=-1):
        """
        Arguments:
            videos  (NestedTensor): N * C * H * W frames, N = sum(T) over the batch.
            texts   (NestedTensor): referring expressions, consumed by the text encoder.
            targets (list[TargetTensor]): ground truth (not read by this method).
            iteration_rate: not read by this method.
        Returns:
            dict with "pred_boxes" (final-layer boxes, sigmoid-normalized), plus
            "aux_outputs" when USE_AUX_LOSS is set and "imgweights"/"spatial_shapes"
            when USE_ATTN is set.
        """
        # Visual features: each frame of each video goes through the backbone.
        features, pos = self.vis_encoder(videos)
        vis_outputs, vis_pos_embed = features[-1], pos[-1]
        vis_features, vis_mask, vis_durations = vis_outputs.decompose()
        vis_features = self.input_proj(vis_features)
        vis_outputs = NestedTensor(vis_features, vis_mask, vis_durations)

        # Higher-resolution backbone levels, projected to hidden_dim.
        viss4_outputs, viss4_pos_embed = features[-2], pos[-2]
        viss3_outputs, viss3_pos_embed = features[-3], pos[-3]
        viss4_features, viss4_mask, _ = viss4_outputs.decompose()
        viss3_features, viss3_mask, _ = viss3_outputs.decompose()
        viss4_features = self.inputs4_proj(viss4_features)
        viss3_features = self.inputs3_proj(viss3_features)

        # Textual features.
        device = vis_features.device
        text_outputs, _ = self.text_encoder(texts, device)

        # Multimodal encoding on the s5 level.
        encoded_memory = self.ground_encoder(videos=vis_outputs,
                                             vis_pos=vis_pos_embed, texts=text_outputs,)

        # Split the joint memory into visual tokens (first h*w) and text tokens.
        imgtext_memory = encoded_memory["encoded_memory"]  # n_token x n_frames x d_model
        fea_map_size = encoded_memory["fea_map_size"]  # (H, W) of the s5 feature map
        h, w = fea_map_size[0], fea_map_size[1]
        n_vis_tokens = h * w
        text_memory = imgtext_memory[n_vis_tokens:]
        viss5_memory = imgtext_memory[:n_vis_tokens].permute(1, 2, 0).reshape(-1, self.hidden_dim, h, w).contiguous()

        outs = self._fuse_cross_scale(viss3_features, viss4_features, viss5_memory)

        MS_vis_memory = []
        spatial_shapes = []
        multi_text_feats = []
        for feat in outs:
            feat_h, feat_w = feat.shape[-2], feat.shape[-1]
            img_feat = feat.flatten(2).permute(2, 0, 1)  # [hw, T, d]
            aug_text = self.encoder_Vision_aug_Text(img_feat, text_memory)  # [BT, N, D]
            multi_text_feats.append(aug_text.permute(1, 0, 2))  # [N, BT, D]
            MS_vis_memory.append(img_feat)
            spatial_shapes.append([feat_h, feat_w])  # [num_levels, 2]

        encoded_memory["multiscale_vis"] = MS_vis_memory  # [s3, s4, s5]
        viss3_mask[:, 0, 0] = False  # avoid empty masks
        viss4_mask[:, 0, 0] = False  # avoid empty masks
        encoded_memory["multiscale_mask"] = [viss3_mask.flatten(1), viss4_mask.flatten(1), encoded_memory["vis_mask"]]
        encoded_memory["multiscale_shapes"] = spatial_shapes
        encoded_memory["multi_text_feats"] = multi_text_feats  # [s3-low, s4-mid, s5-high]
        encoded_memory["multiscale_pos"] = [viss3_pos_embed.flatten(2).permute(2, 0, 1),
                                            viss4_pos_embed.flatten(2).permute(2, 0, 1),
                                            vis_pos_embed.flatten(2).permute(2, 0, 1)]

        # Query-based decoding; img_weights holds one attention map per decoder layer.
        outputs, img_weights = self.ground_decoder(
            memory_cache=encoded_memory,
            vis_pos=vis_pos_embed, vis_poss4=viss4_pos_embed, vis_poss3=viss3_pos_embed,
        )

        out = {}
        if self.use_attn:
            out["imgweights"] = img_weights[-1]
            out["spatial_shapes"] = spatial_shapes

        # Final decoder embeddings and reference anchors.
        hs, reference = outputs  # hs : [num_layers, b, T, d_model], reference : [num_layers, b, T, 4]

        # Box prediction: offsets relative to the reference anchors in logit space.
        reference_before_sigmoid = inverse_sigmoid(reference)
        tmp = self.bbox_embed(hs)
        tmp[..., :self.query_dim] += reference_before_sigmoid
        outputs_coord = tmp.sigmoid().flatten(1, 2)  # [num_layers, b*T, 4]
        out["pred_boxes"] = outputs_coord[-1]

        # Auxiliary outputs from every earlier decoder layer.
        if self.use_aux_loss:
            out["aux_outputs"] = [{"pred_boxes": b} for b in outputs_coord[:-1]]
            if self.use_attn:
                for i_aux, aux in enumerate(out["aux_outputs"]):
                    aux["imgweights"] = img_weights[i_aux]
                    aux["spatial_shapes"] = spatial_shapes

        return out

class TextaugVision(nn.Module):
    """Re-weight image tokens by a text-guided sigmoid gate.

    Image tokens are compared against every projected text token by dot
    product. The word similarities are pooled with learned, softmax-normalized
    per-word weights. The result is squashed by a sigmoid to one gate value per
    image token, which multiplies the original image features.
    """

    def __init__(self,
                 text_channels: int = 256,
                 embed_channels: int = 256,
                 ):
        super().__init__()
        self.text_channels = text_channels
        self.embed_channels = embed_channels

        self.img_proj = nn.Linear(embed_channels, text_channels)
        self.text_fc = nn.Linear(text_channels, embed_channels, bias=False)

        # Per-word importance score.
        self.fp = nn.Sequential(
            nn.Linear(text_channels, text_channels),
            nn.LayerNorm(text_channels),
            nn.GELU(),
            nn.Linear(text_channels, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, txt_feats, img_feat):
        """txt_feats: (N, BT, D); img_feat: (HW, BT, D).

        Returns (weight (BT, HW, 1), augmented image features (HW, BT, D)).
        """
        img_batch_first = img_feat.permute(1, 0, 2)  # BT, HW, D
        queries = self.img_proj(img_feat).permute(1, 0, 2)  # BT, HW, D
        keys = self.text_fc(txt_feats).permute(1, 2, 0)  # BT, D, N
        similarity = torch.matmul(queries, keys)  # BT, HW, N

        # Softmax-normalized importance of each word.
        word_weights = F.softmax(self.fp(txt_feats.permute(1, 0, 2)), dim=1)  # BT, N, 1
        # Weighted pooling over the word dimension.
        pooled = torch.sum(word_weights * similarity.permute(0, 2, 1), dim=1, keepdim=True)  # BT, 1, HW
        weight = self.sigmoid(pooled).permute(0, 2, 1)  # BT, HW, 1

        gated = weight * img_batch_first  # BT, HW, D
        return weight, gated.permute(1, 0, 2)  # HW, BT, D


class Text_aug_Vision(nn.Module):
    """Multi-head cross-attention: image tokens (queries) attend to text tokens (keys/values).

    The attended text context is projected and added residually to the image
    features. Inputs are sequence-first, txt_feat (N, BT, D) and img_feat
    (HW, BT, D); the output is (HW, BT, D).
    """

    def __init__(self,
                 text_channels: int = 256,
                 embed_channels: int = 256,
                 num_heads: int = 8,
                 ):
        super().__init__()
        self.num_heads = num_heads
        self.text_channels = text_channels
        self.embed_channels = embed_channels

        self.hidden_channel = text_channels
        self.head_channels = text_channels // num_heads

        self.img_proj = nn.Linear(embed_channels, text_channels)

        def ln_linear():
            # LayerNorm followed by a square Linear layer.
            return nn.Sequential(nn.LayerNorm(self.hidden_channel),
                                 nn.Linear(self.hidden_channel, self.hidden_channel))

        self.query = ln_linear()
        self.key = ln_linear()
        self.value = ln_linear()
        self.proj = nn.Linear(self.hidden_channel, text_channels)

    def forward(self, txt_feat, img_feat):
        n_bt = img_feat.shape[1]
        img_seq = self.img_proj(img_feat).permute(1, 0, 2)  # BT, HW, D
        txt_seq = txt_feat.permute(1, 0, 2)  # BT, N, D

        def split_heads(t):
            # BT, L, D -> BT, heads, L, head_dim
            return t.reshape(n_bt, -1, self.num_heads, self.head_channels).permute(0, 2, 1, 3)

        q = split_heads(self.query(img_seq))  # BT, h, HW, d
        k = split_heads(self.key(txt_seq))  # BT, h, N, d
        v = split_heads(self.value(txt_seq))  # BT, h, N, d

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_channels ** 0.5)  # BT, h, HW, N
        attn = F.softmax(scores, dim=-1)

        context = torch.matmul(attn, v)  # BT, h, HW, d
        context = context.permute(0, 2, 1, 3).reshape(n_bt, -1, self.hidden_channel)  # BT, HW, D

        return img_feat + self.proj(context).permute(1, 0, 2)  # HW, BT, D


class Vision_aug_Text(nn.Module):
    """Multi-head cross-attention: text tokens (queries) attend to image tokens (keys/values).

    The attended image context is projected and added residually to the text
    features. Inputs are sequence-first, img_feat (HW, BT, embed_channels) and
    txt_feat (N, BT, text_channels); the output is batch-first
    (BT, N, text_channels).
    """

    def __init__(self,
                 text_channels: int = 256,
                 embed_channels: int = 256,
                 num_heads: int = 8,
                 ):
        super().__init__()
        self.num_heads = num_heads
        self.text_channels = text_channels
        self.embed_channels = embed_channels

        self.hidden_channel = text_channels
        self.head_channels = text_channels // num_heads

        self.img_proj = nn.Linear(embed_channels, text_channels)

        def ln_linear():
            # LayerNorm followed by a square Linear layer.
            return nn.Sequential(nn.LayerNorm(self.hidden_channel),
                                 nn.Linear(self.hidden_channel, self.hidden_channel))

        self.query = ln_linear()
        self.key = ln_linear()
        self.value = ln_linear()
        self.proj = nn.Linear(self.hidden_channel, text_channels)

    def forward(self, img_feat, txt_feat):
        n_bt = img_feat.shape[1]
        img_seq = self.img_proj(img_feat).permute(1, 0, 2)  # BT, HW, D
        txt_seq = txt_feat.permute(1, 0, 2)  # BT, N, D

        def split_heads(t):
            # BT, L, D -> BT, heads, L, head_dim
            return t.reshape(n_bt, -1, self.num_heads, self.head_channels).permute(0, 2, 1, 3)

        q = split_heads(self.query(txt_seq))  # BT, h, N, d
        k = split_heads(self.key(img_seq))  # BT, h, HW, d
        v = split_heads(self.value(img_seq))  # BT, h, HW, d

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_channels ** 0.5)  # BT, h, N, HW
        attn = F.softmax(scores, dim=-1)

        context = torch.matmul(attn, v)  # BT, h, N, d
        context = context.permute(0, 2, 1, 3).reshape(n_bt, -1, self.hidden_channel)  # BT, N, D

        return txt_seq + self.proj(context)


def interframe_correspondence(encoded_memory, bboxs):
    """Pool region features for ground-truth boxes and random negative boxes.

    Args:
        encoded_memory: dict read for "encoded_memory" (token-first memory whose
            first H*W tokens are visual) and "fea_map_size" (H, W).
        bboxs: normalized (cx, cy, w, h) boxes, one per frame.

    Returns:
        frame_feature: ground-truth region features of frames 0..T-2.
        corre_feature: ground-truth region features of frames 1..T-1.
        neg_feature: negative-box features of frames 0..T-2.
        region_feature: [T, 1, D] ground-truth region features.
        neg_region_feature: [T, Nneg, D] negative-box features.
    """
    imgtext_memory = encoded_memory["encoded_memory"]
    fea_map_size = encoded_memory["fea_map_size"]  # (H,W) the feature map size
    fea_h, fea_w = fea_map_size[0], fea_map_size[1]
    n_vis_tokens = fea_h * fea_w

    # The first H*W memory tokens are the visual ones.
    vis_feature_map = imgtext_memory.permute(1, 0, 2)[:, :n_vis_tokens]  # [bT, H*W, C]
    vis_feature_map = vis_feature_map.reshape(-1, fea_h, fea_w, imgtext_memory.size(-1))  # [bT, H, W, C]
    T, H, W, _ = vis_feature_map.shape

    # Normalized xyxy -> feature-cell coordinates. reshape(-1, 4) (rather than
    # squeeze) keeps a 2-D [T, 4] layout even when there is a single frame.
    scale = torch.Tensor([fea_w, fea_h, fea_w, fea_h]).to(bboxs.device)
    boxes = box_cxcywh_to_xyxy(bboxs).clamp(min=0).reshape(-1, 4) * scale
    # Start coordinates are rounded. The exclusive end coordinates are rounded
    # up on both axes, so partially covered cells are included.
    boxes = torch.stack(
        [boxes[:, 0].round(), boxes[:, 1].round(), boxes[:, 2].ceil(), boxes[:, 3].ceil()],
        dim=-1,
    ).long().unsqueeze(1)  # [T, 1, 4]

    region_feature = get_roi_feature(boxes, vis_feature_map)  # [T, Npos, D]

    num_negatives = 5
    neg_bboxs = generate_negative_boxes(T, H, W, num_negatives, boxes)  # [T, Nneg, 4]
    neg_region_feature = get_roi_feature(neg_bboxs, vis_feature_map)  # [T, Nneg, D]

    frame_feature, corre_feature = region_feature[:-1], region_feature[1:]
    neg_feature = neg_region_feature[:-1]
    return frame_feature, corre_feature, neg_feature, region_feature, neg_region_feature


def get_roi_feature(bboxs, vis_feature_map):
    """Mean-pool the feature-map cells inside each box.

    Args:
        bboxs: indexable as bboxs[i][n] -> (x1, y1, x2, y2) integer cell
            coordinates for frame i, box n. x2/y2 are exclusive ends.
        vis_feature_map: [T, H, W, D] tensor.

    Returns:
        [len(bboxs), N, D] tensor of pooled features.

    Each box is clamped to the map and forced to span at least one cell, so
    every pooled region is non-empty.
    """
    dim = vis_feature_map.size(-1)
    pooled_frames = []
    for i, frame_boxes in enumerate(bboxs):
        frame_map = vis_feature_map[i]  # [H, W, D]; read-only, no copy needed
        height, width = frame_map.size(0), frame_map.size(1)
        # One host conversion per frame instead of per-coordinate tensor comparisons.
        if torch.is_tensor(frame_boxes):
            frame_boxes = frame_boxes.tolist()
        pooled = []
        for box in frame_boxes:
            x1, y1, x2, y2 = (int(v) for v in box)
            x2 = min(max(x2, 1), width)
            x1 = min(max(x1, 0), x2 - 1)
            y2 = min(max(y2, 1), height)
            y1 = min(max(y1, 0), y2 - 1)
            pooled.append(frame_map[y1:y2, x1:x2].reshape(-1, dim).mean(dim=0))
        pooled_frames.append(torch.stack(pooled))
    return torch.stack(pooled_frames)


def _sample_grid_box(H, W):
    """Draw one random [x1, y1, x2, y2] box on an H x W grid (x1 <= x2 < W, y1 <= y2 < H)."""
    # Draw order (x1, y1, x2, y2) fixes how the global torch RNG is consumed.
    x1 = int(torch.randint(0, W, (1,))[0])
    y1 = int(torch.randint(0, H, (1,))[0])
    x2 = int(torch.randint(x1, W, (1,))[0])
    y2 = int(torch.randint(y1, H, (1,))[0])
    return [x1, y1, x2, y2]


def _box_inside(box, ref):
    """True if ``box`` lies entirely within ``ref`` (both [x1, y1, x2, y2])."""
    return box[0] >= ref[0] and box[1] >= ref[1] and box[2] <= ref[2] and box[3] <= ref[3]


def generate_negative_boxes(T, H, W, num_negatives, true_boxes):
    """Sample random negative boxes on an H x W grid for each of T frames.

    For each of the ``num_negatives`` slots, one box is drawn and reused for
    consecutive frames. It is redrawn only when it lies entirely inside that
    frame's ground-truth box. After more than 10 redraws for one frame, any box
    that is not identical to the ground truth is accepted, so sampling always
    terminates.

    Args:
        T, H, W: number of frames and grid height/width.
        num_negatives: boxes per frame.
        true_boxes: ground-truth cell boxes indexable as true_boxes[i][0] ->
            (x1, y1, x2, y2) for frame i.

    Returns:
        torch.int32 tensor of shape [T, num_negatives, 4] (on the CPU).
    """
    neg_boxes = torch.zeros((T, num_negatives, 4), dtype=torch.int)
    if num_negatives == 0:
        return neg_boxes

    # Read the ground truth to the host once, instead of syncing per comparison.
    if torch.is_tensor(true_boxes):
        true_boxes = true_boxes.tolist()
    gt_boxes = [[int(v) for v in true_boxes[i][0]] for i in range(T)]

    for j in range(num_negatives):
        box = _sample_grid_box(H, W)
        for i in range(T):
            gt = gt_boxes[i]
            attempts = 0
            while True:
                attempts += 1
                # Accept boxes that are not fully contained in the ground truth.
                if not _box_inside(box, gt):
                    break
                box = _sample_grid_box(H, W)
                if attempts > 10 and box != gt:
                    break
            neg_boxes[i, j] = torch.tensor(box, dtype=torch.int)

    return neg_boxes
