import copy
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from typing import List, Optional, Tuple
# from mmcv.cnn import Linear, ConvModule
from utils.misc import NestedTensor
from .position_encoding import SeqEmbeddingLearned, SeqEmbeddingSine


class CrossModalEncoder(nn.Module):
    """Joint visual/text encoder working on per-frame token sequences.

    Frame feature maps are flattened into token sequences, the text tokens are
    replicated once per frame, both are concatenated along the token axis and
    the result is passed through a ``SpatialTemporalEncoder``.
    """

    def __init__(self, cfg):
        """Build the encoder stack from the ``cfg.MODEL.SAVGDETR`` settings."""
        super().__init__()
        model_cfg = cfg.MODEL.SAVGDETR
        hidden_dim = model_cfg.HIDDEN
        num_heads = model_cfg.HEADS
        ffn_dim = model_cfg.FFN_DIM
        drop_prob = model_cfg.DROPOUT
        depth = model_cfg.ENC_LAYERS
        self.d_model = hidden_dim

        # A single template layer; SpatialTemporalEncoder receives it together
        # with the number of layers. No final norm is used.
        layer_template = TransformerEncoderLayer(
            hidden_dim, num_heads, ffn_dim, drop_prob, "relu"
        )
        self.encoder = SpatialTemporalEncoder(cfg, layer_template, depth, None)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, videos : NestedTensor = None, videoss4 : NestedTensor = None, videoss3 : NestedTensor = None,
                vis_pos = None, vis_poss4 = None, vis_poss3 = None, texts : Tuple = None):
        """Encode frame features together with the text tokens.

        Args:
            videos: nested tensor whose ``decompose()`` yields
                ``(features, mask, durations)``; features are unpacked as a
                4-D ``n_frames x c x h x w`` map.
            videoss4, videoss3, vis_poss4, vis_poss3: accepted for interface
                compatibility; not used by the current implementation.
            vis_pos: positional embedding of the frame features; its first
                dimension must equal the total number of frames.
            texts: 3-tuple ``(attention_mask, text_memory, _)``; per the
                expansion below the mask is indexed per video on dim 0 and the
                memory per video on dim 1.

        Returns:
            dict with the encoded tokens, the masks, the frame/video cls
            tokens, the durations and the feature-map size.
        """
        frame_feats, frame_mask, durations = videos.decompose()
        assert vis_pos.shape[0] == sum(durations)

        # Un-mask the first position of every frame to avoid empty masks
        # (note: this writes into the mask returned by ``decompose``).
        frame_mask[:, 0, 0] = False

        _, _, feat_h, feat_w = frame_feats.shape

        # n_frames x c x h x w => hw x n_frames x c
        vis_tokens = frame_feats.flatten(2).permute(2, 0, 1)
        vis_pos_tokens = vis_pos.flatten(2).permute(2, 0, 1)
        # n_frames x h x w => n_frames x hw
        vis_mask = frame_mask.flatten(1)

        text_pad_mask, text_tokens, _ = texts

        # Replicate each video's text mask once per frame: [b, len] -> [n_frames, len]
        text_attention_mask = torch.cat(
            [
                torch.stack([text_pad_mask[vid] for _ in range(dur)])
                for vid, dur in enumerate(durations)
            ]
        )

        # Replicate each video's text tokens once per frame:
        # [len, b, d_model] -> [len, n_frames, d_model]
        frame_text_tokens = torch.cat(
            [
                torch.stack([text_tokens[:, vid] for _ in range(dur)], dim=1)
                for vid, dur in enumerate(durations)
            ],
            dim=1,
        )

        # Visual tokens first, text tokens after; text tokens get an all-zero
        # positional embedding.
        features = torch.cat([vis_tokens, frame_text_tokens], dim=0)
        mask = torch.cat([vis_mask, text_attention_mask], dim=1)
        pos_embed = torch.cat(
            [vis_pos_tokens, torch.zeros_like(frame_text_tokens)], dim=0
        )

        # Cross-modality interaction
        img_memory, frames_cls, videos_cls = self.encoder(
            features,
            src_key_padding_mask=mask,
            vis_pos=pos_embed,
            durations=durations,
            fea_map_size=(feat_h, feat_w),
        )

        return {
            "encoded_memory": img_memory,
            "mask": mask,  # batch first
            "frames_cls": frames_cls,  # n_frame, d_model
            "videos_cls": videos_cls,  # b , d_model
            "durations": durations,
            "fea_map_size": (feat_h, feat_w),
            "text_mask": text_attention_mask,
            "vis_mask": vis_mask,
        }



class _ConvReLU(nn.Module):
    """Conv2d followed by an in-place ReLU.

    Local replacement for the ``ConvModule`` name the original code used,
    whose import is commented out in this file. The submodule names ``conv``
    and ``activate`` are chosen to mirror that module.
    NOTE(review): weight initialisation here is PyTorch's ``nn.Conv2d``
    default, which may differ from the external module's - confirm if
    checkpoints or init parity matter.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.activate(self.conv(x))


class Vision_aug_Text(nn.Module):
    """Augment text tokens with visual context through multi-head cross attention.

    Text tokens act as queries; the convolved image feature map provides keys
    and values. The attended result is projected and added back to the text
    tokens (residual connection).
    """

    def __init__(self, embed_channels=256,num_heads= 8, ):
        """
        Args:
            embed_channels: channel width shared by image and text features.
            num_heads: number of attention heads.

        Raises:
            ValueError: if ``embed_channels`` is not divisible by ``num_heads``
                (the per-head reshape in ``forward`` would otherwise fail or
                produce a wrong sequence length).
        """
        super().__init__()
        if embed_channels % num_heads != 0:
            raise ValueError(
                f"embed_channels ({embed_channels}) must be divisible by "
                f"num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.hidden_channel = embed_channels
        self.head_channels = self.hidden_channel // num_heads

        # 3x3 conv (expand) -> 3x3 conv (reduce) -> 1x1 conv on the feature map.
        self.img_proj = nn.Sequential(
            _ConvReLU(in_channels=self.hidden_channel,
                      out_channels=self.hidden_channel * 2,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      ),
            _ConvReLU(in_channels=self.hidden_channel * 2,
                      out_channels=self.hidden_channel,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      ),
            nn.Conv2d(in_channels=self.hidden_channel,
                      out_channels=self.hidden_channel,
                      kernel_size=1))

        self.query = nn.Sequential(nn.LayerNorm(self.hidden_channel),
                                   nn.Linear(self.hidden_channel, self.hidden_channel))
        self.key = nn.Sequential(nn.LayerNorm(self.hidden_channel),
                                 nn.Linear(self.hidden_channel, self.hidden_channel))
        self.value = nn.Sequential(nn.LayerNorm(self.hidden_channel),
                                   nn.Linear(self.hidden_channel, self.hidden_channel))

        self.proj = nn.Linear(self.hidden_channel, self.hidden_channel)

    def forward(self, img_feat: Tensor, H, W, txt_feat: Tensor) -> Tensor:
        """
        Args:
            img_feat: image tokens, ``[H*W, B, D]``.
            H, W: spatial size used to fold the tokens back into a map.
            txt_feat: text tokens, ``[Nt, B, D]``.

        Returns:
            Augmented text tokens, ``[Nt, B, D]``.
        """
        _, B, D = img_feat.shape
        # [HW, B, D] -> [B, D, H, W] so the conv stack can run on a feature map
        img_map = img_feat.permute(1, 2, 0).reshape(B, D, H, W)
        img_tokens = self.img_proj(img_map)
        # [B, D, H, W] -> [B, HW, D]
        img_tokens = img_tokens.permute(0, 2, 3, 1).reshape(B, H * W, -1)

        txt_feat = txt_feat.permute(1, 0, 2)  # [B, Nt, D]

        q = self.query(txt_feat)
        k = self.key(img_tokens)
        v = self.value(img_tokens)

        # Split channels into heads: [B, L, heads, head_channels]
        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)

        q = q.permute(0, 2, 1, 3)  # [B, heads, Nt, hc]
        k = k.permute(0, 2, 3, 1)  # [B, heads, hc, HW]
        v = v.permute(0, 2, 1, 3)  # [B, heads, HW, hc]

        # Scaled dot-product attention over the image positions
        attn_weight = torch.matmul(q, k) / (self.head_channels ** 0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)

        aug_v = torch.matmul(attn_weight, v)  # [B, heads, Nt, hc]
        aug_v = aug_v.permute(0, 2, 1, 3).reshape(B, -1, self.hidden_channel)
        aug_text_feat = self.proj(aug_v)

        # Residual connection, then back to [Nt, B, D]
        return (txt_feat + aug_text_feat).permute(1, 0, 2)


class TextaugVision(nn.Module):
    """Gate visual tokens with a text-conditioned, per-position weight.

    A similarity map between projected image tokens and projected text tokens
    is pooled over the words (with learned per-word weights), squashed by a
    sigmoid and multiplied onto the original image tokens.

    The image/text similarity is a plain matmul between the ``img_proj``
    output (``text_channels`` wide) and the ``text_fc`` output
    (``embed_channels`` wide), so the two widths have to match.
    """

    def __init__(self,
                 text_channels: int = 256,
                 embed_channels: int = 256,
                 ):
        super().__init__()
        self.text_channels = text_channels
        self.embed_channels = embed_channels

        self.img_proj = nn.Linear(embed_channels, text_channels)
        self.text_fc = nn.Linear(text_channels, embed_channels, bias=False)

        # Scores each word with a single scalar (softmax-normalised in forward).
        self.fp = nn.Sequential(
            nn.Linear(text_channels, text_channels),
            nn.LayerNorm(text_channels),
            nn.GELU(),
            nn.Linear(text_channels, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, txt_feats, img_feat):
        """
        Args:
            txt_feats: text tokens, ``[N, BT, text_channels]``.
            img_feat: image tokens, ``[HW, BT, embed_channels]``.

        Returns:
            ``(weight, aug_visual_feat)`` where ``weight`` is ``[BT, HW, 1]``
            and ``aug_visual_feat`` is ``[HW, BT, D]``.
        """
        img_batch_first = img_feat.permute(1, 0, 2)  # BT, HW, D
        txt_batch_first = txt_feats.permute(1, 0, 2)  # BT, N, D

        img_query = self.img_proj(img_feat).permute(1, 0, 2)  # BT, HW, D
        txt_key = self.text_fc(txt_feats).permute(1, 0, 2)  # BT, N, D

        # Similarity of every image position with every word
        similarity = torch.matmul(img_query, txt_key.permute(0, 2, 1))  # BT, HW, N

        # Attention weight of each word
        word_weight = F.softmax(self.fp(txt_batch_first), dim=1)  # BT, N, 1

        # Weighted pooling along the word dimension
        pooled = torch.sum(word_weight * similarity.permute(0, 2, 1), dim=1)  # BT, HW
        pooled = pooled.unsqueeze(1)  # BT, 1, HW
        weight = self.sigmoid(pooled).permute(0, 2, 1)  # BT, HW, 1

        aug_visual_feat = weight * img_batch_first  # BT, HW, D
        return weight, aug_visual_feat.permute(1, 0, 2)  # HW, BT, D

class SpatialTemporalEncoder(nn.Module):
    """Encoder alternating per-frame (spatial) and per-video (temporal) layers.

    A learned frame-level cls token is prepended to every frame's token
    sequence. After each spatial layer the frame cls tokens of a video are
    gathered (together with a video-level cls token) into a padded sequence,
    run through the matching temporal layer, and written back into the frame
    sequences.
    """

    def __init__(self, cfg, encoder_layer, num_layers, norm=None, return_weights=False):
        """
        Args:
            cfg: config providing ``INPUT.MAX_VIDEO_LEN``,
                ``MODEL.SAVGDETR.HIDDEN`` and
                ``MODEL.SAVGDETR.USE_LEARN_TIME_EMBED``.
            encoder_layer: template layer, deep-copied ``num_layers`` times for
                the spatial stack and again for the temporal stack.
            num_layers: number of spatial/temporal layer pairs.
            norm: optional module applied to the final token sequence.
            return_weights: stored on the instance; not used in ``forward``.
        """
        super().__init__()
        self.spatial_layers = _get_clones(encoder_layer, num_layers)
        self.temporal_layers = _get_clones(encoder_layer, num_layers)
        video_max_len = cfg.INPUT.MAX_VIDEO_LEN
        d_model = cfg.MODEL.SAVGDETR.HIDDEN
        self.d_model = d_model

        # The position embedding of global tokens (+1 for the video cls slot)
        if cfg.MODEL.SAVGDETR.USE_LEARN_TIME_EMBED:
            self.time_embed = SeqEmbeddingLearned(video_max_len + 1, d_model)
        else:
            self.time_embed = SeqEmbeddingSine(video_max_len + 1, d_model)

        # The learned position embedding for the frame cls token
        self.local_pos_embed = nn.Embedding(1, d_model)

        # The learned local and global cls tokens
        self.frame_cls = nn.Embedding(1, d_model)  # the frame level local cls token
        self.video_cls = nn.Embedding(1, d_model)  # the video level global cls token

        self.num_layers = num_layers
        self.norm = norm
        self.return_weights = return_weights

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        vis_pos: Optional[Tensor] = None,
        durations=None,
        fea_map_size = None,
        text_memory = None,
        text_mask = None,
    ):
        """
        Args:
            src: token sequence, ``[N, n_frames, d_model]``.
            mask: attention mask forwarded to the spatial layers as ``src_mask``.
            src_key_padding_mask: padding mask, ``[n_frames, N]``.
            vis_pos: positional embedding with the same layout as ``src``.
            durations: number of frames of each video (list of ints; their sum
                is the frame dimension of ``src``).
            fea_map_size, text_memory, text_mask: accepted for interface
                compatibility; not used here.

        Returns:
            ``(tokens, frame_cls, video_cls)``: the encoded tokens without the
            cls slot ``[N, n_frames, d_model]``, the frame cls tokens
            ``[n_frames, d_model]`` and the video cls tokens ``[b, d_model]``.
        """
        output = src
        b = len(durations)
        t = max(durations)
        n_frames = sum(durations)
        device = output.device

        # Frame cls token, its position embedding and its (never padded) mask slot
        frame_src = self.frame_cls.weight.unsqueeze(1).repeat(1, n_frames, 1)  # 1 x n_frames x d_model
        frame_pos = self.local_pos_embed.weight.unsqueeze(1).repeat(1, n_frames, 1)  # 1 x n_frames x d_model
        frame_mask = torch.zeros((n_frames, 1), dtype=torch.bool, device=device)

        output = torch.cat([frame_src, output], dim=0)  # [1 + N, n_frames, d_model]
        src_key_padding_mask = torch.cat([frame_mask, src_key_padding_mask], dim=1)
        pos = torch.cat([frame_pos, vis_pos], dim=0)

        # Video cls token, temporal position embedding and padding mask
        video_src = self.video_cls.weight.unsqueeze(0).repeat(b, 1, 1)  # b x 1 x d_model
        temp_pos = self.time_embed(t + 1).repeat(1, b, 1)  # (t + 1) x b x d_model
        temp_mask = torch.ones(b, t + 1, dtype=torch.bool, device=device)
        temp_mask[:, 0] = False  # slot 0 holds the video cls token
        for i_dur, dur in enumerate(durations):
            temp_mask[i_dur, 1 : 1 + dur] = False

        for i_layer, layer in enumerate(self.spatial_layers):
            # spatial interaction on each single frame
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

            # Gather [video cls, frame cls ...] per video into a zero-padded
            # buffer. Allocated from `output` so it shares its dtype and device
            # (no float32 round trip, no host-to-device copy per layer).
            frames_src = output.new_zeros(b, t + 1, self.d_model)  # b x (t + 1) x d_model
            frames_src_list = torch.split(output[0, :, :], durations)
            for i_dur, dur in enumerate(durations):
                frames_src[i_dur, 0 : 1, :] = video_src[i_dur]
                frames_src[i_dur, 1 : 1 + dur, :] = frames_src_list[i_dur]
            frames_src = frames_src.permute(1, 0, 2)  # b x len x d => len x b x d

            # temporal interaction between all frames of a video
            frames_src = self.temporal_layers[i_layer](
                frames_src,
                src_mask=None,
                src_key_padding_mask=temp_mask,
                pos=temp_pos,
            )

            frames_src = frames_src.permute(1, 0, 2)  # len x b x d => b x len x d
            # dispatch the temporal context back to each frame cls token
            frames_src_list = []
            for i_dur, dur in enumerate(durations):
                video_src[i_dur] = frames_src[i_dur, 0 : 1]
                frames_src_list.append(frames_src[i_dur, 1 : 1 + dur])

            frames_src = torch.cat(frames_src_list, dim=0)  # n_frames x d_model
            output[0, :, :] = frames_src

        if self.norm is not None:
            output = self.norm(output)

        frame_src = output[0, :, :]
        output = output[1:, :, :]
        video_src = video_src.squeeze(1)  # b x 1 x d_model => b x d_model

        return output, frame_src, video_src


class TransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer: self-attention followed by a feed-forward block.

    The positional embedding is added to queries and keys only; values use the
    raw input.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Norm + dropout for the two residual branches
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Apply self-attention and the feed-forward block, each with residual + LayerNorm.

        Args:
            src: input token sequence.
            src_mask: passed to the attention module as ``attn_mask``.
            src_key_padding_mask: passed to the attention module as ``key_padding_mask``.
            pos: optional positional embedding added to queries and keys.
        """
        # Self-attention branch
        query_key = self.with_pos_embed(src, pos)
        attn_out, _ = self.self_attn(
            query_key,
            query_key,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = self.norm1(src + self.dropout1(attn_out))

        # Feed-forward branch
        hidden = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(ffn_out))


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
