import torch
import torch.nn.functional as F
from torch import nn, Tensor

class Text_guide_linear(nn.Module):
    """Text-guided gating for visual features given as a token sequence.

    Thin wrapper that owns an ``Attention_layer`` configured with
    ``isCNN=False`` and forwards calls to it.
    """

    def __init__(self, d_model=256, d_model_visual=256):
        """Build the wrapped attention layer.

        Args:
            d_model: Forwarded as ``language_in_channels``.
            d_model_visual: Forwarded as ``visual_in_channels``.
        """
        super().__init__()
        self.TC_attention = Attention_layer(
            language_in_channels=d_model,
            visual_in_channels=d_model_visual,
            isCNN=False,
        )

    def forward(self, x, word_feat_info):
        """Apply the text-guided attention to ``x``.

        Note the argument order is swapped when delegating: the wrapped
        layer takes the word information first and the visual features second.

        Args:
            x: Visual features handed to the attention layer.
            word_feat_info: Word features and mask handed to the attention layer.

        Returns:
            Whatever the wrapped ``Attention_layer`` returns (the re-weighted
            features together with the spatial attention map).
        """
        return self.TC_attention(word_feat_info, x)

class Text_guide_conv(nn.Module):
    """Text-guided gating for visual features given as a feature map.

    Thin wrapper that owns an ``Attention_layer`` left in its default
    (``isCNN=True``) mode and forwards calls to it.
    """

    def __init__(self, d_model=256, d_model_visual=256):
        """Build the wrapped attention layer.

        Args:
            d_model: Forwarded as ``language_in_channels``.
            d_model_visual: Forwarded as ``visual_in_channels``.
        """
        super().__init__()
        self.TC_attention = Attention_layer(
            language_in_channels=d_model,
            visual_in_channels=d_model_visual,
        )

    def forward(self, x, word_feat_info):
        """Apply the text-guided attention to ``x``.

        Note the argument order is swapped when delegating: the wrapped
        layer takes the word information first and the visual features second.

        Args:
            x: Visual features handed to the attention layer.
            word_feat_info: Word features and mask handed to the attention layer.

        Returns:
            Whatever the wrapped ``Attention_layer`` returns (the re-weighted
            features together with the spatial attention map).
        """
        return self.TC_attention(word_feat_info, x)

class AttentionPooling(nn.Module):
    """Pool a sequence into a single vector with learned, mask-aware weights.

    A small MLP (Linear -> LayerNorm -> GELU -> Linear) produces one score per
    sequence position; scores are softmax-normalised along the sequence axis
    and used for a weighted sum of the inputs.
    """

    def __init__(self, in_dim):
        """
        Args:
            in_dim: Size of the last (feature) dimension of the input.
        """
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, 1),
        )

    def forward(self, last_hidden_state, attention_mask):
        """Compute the attention-weighted sum over the sequence axis.

        Args:
            last_hidden_state: Sequence-first features of shape ``[L, B, C]``.
            attention_mask: Tensor of shape ``[B, L]``; positions whose value
                is ``0`` are excluded from the pooling.

        Returns:
            Tensor of shape ``[1, B, C]``. A sequence whose positions are all
            masked pools to a zero vector (instead of NaN).
        """
        # Scores are computed in float32 so the softmax is numerically stable
        # even when the module runs in reduced precision.
        scores = self.attention(last_hidden_state).float()

        # [B, L] -> [L, B, 1] so the mask lines up with the scores.
        excluded = attention_mask.permute(1, 0).unsqueeze(2) == 0
        scores = scores.masked_fill(excluded, float('-inf'))

        # A column that is entirely -inf would make softmax (and its gradient)
        # NaN. Give such columns a finite placeholder score; their weights are
        # zeroed right after the softmax, so the placeholder never leaks out.
        fully_masked = excluded.all(dim=0, keepdim=True)
        scores = scores.masked_fill(fully_masked, 0.0)

        weights = torch.softmax(scores, 0).masked_fill(excluded, 0.0)
        return torch.sum(weights * last_hidden_state, dim=0, keepdim=True)

class Attention_layer(nn.Module):
    """Gate visual features with a sentence-level text embedding.

    The visual features are linearly projected to ``inter_channels``, the word
    features are pooled into one vector per batch element, and the sigmoid of
    their dot product yields one weight per spatial location. The input visual
    features are multiplied by that weight.

    The word features are not projected here, so their channel size must
    already equal ``inter_channels``.
    """

    def __init__(self, language_in_channels, visual_in_channels, inter_channels=None, isCNN=True):
        """
        Args:
            language_in_channels: Accepted for interface compatibility; it is
                not used by this layer.
            visual_in_channels: Channel size of the incoming visual features.
            inter_channels: Size of the joint embedding space; defaults to 256.
            isCNN: ``True`` when the visual input is a ``[B, C, H, W]`` feature
                map, ``False`` when it is a ``[HxW, B, C]`` token sequence.
        """
        super().__init__()
        self.inter_channels = 256 if inter_channels is None else inter_channels
        self.visual_projection = nn.Linear(visual_in_channels, self.inter_channels, bias=False)
        self.CNN = isCNN
        # Kept as a public attribute; nothing in this layer reads it.
        self.use_scale = True
        self.pooling = AttentionPooling(self.inter_channels)

    @staticmethod
    def _resolve_spatial_shape(num_tokens, spatial_shape):
        """Return the (H, W) grid for a token sequence, inferring a square grid if not given."""
        if spatial_shape is not None:
            height, width = spatial_shape
            if height * width != num_tokens:
                raise ValueError(
                    f"spatial_shape {tuple(spatial_shape)} does not match "
                    f"{num_tokens} visual tokens"
                )
            return height, width
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"cannot infer a square spatial grid from {num_tokens} visual "
                "tokens; pass spatial_shape=(H, W)"
            )
        return side, side

    def forward(self, word_feat_info, img_feat, spatial_shape=None):
        """Re-weight ``img_feat`` by its similarity to the pooled text embedding.

        Args:
            word_feat_info: Indexable pair ``(word_features, mask)``.
                ``word_features`` is ``[L, B, C_inner]``; ``mask`` is a boolean
                ``[B, L]`` tensor that is inverted before pooling, so ``True``
                marks positions to ignore.
            img_feat: ``[B, C, H, W]`` in CNN mode, ``[HxW, B, C]`` otherwise.
            spatial_shape: Optional ``(H, W)`` used only in token mode to lay
                out the returned map. When omitted the token count must be a
                perfect square (e.g. 400 tokens -> 20 x 20).

        Returns:
            Tuple ``(gated_img_feat, spatial_map)`` where ``gated_img_feat`` has
            the same layout as ``img_feat`` and ``spatial_map`` is
            ``[B, 1, H, W]`` with values in ``(0, 1)``.

        Raises:
            ValueError: In token mode, if ``spatial_shape`` does not match the
                number of tokens or no square grid can be inferred.
        """
        if self.CNN:
            batch, _, height, width = img_feat.shape
            # [B, C, H, W] -> [HxW, B, C]; flatten also copes with
            # non-contiguous inputs where .view would raise.
            visual_tokens = img_feat.flatten(2).permute(2, 0, 1)
        else:
            batch = img_feat.size(1)
            height, width = self._resolve_spatial_shape(img_feat.size(0), spatial_shape)
            visual_tokens = img_feat

        # [HxW, B, C] -> [HxW, B, C_inner]
        visual_embedding = self.visual_projection(visual_tokens)

        textual_embedding = word_feat_info[0]
        word_mask = ~word_feat_info[1]
        # Pooled sentence embedding: [1, B, C_inner]
        word_attn_embedding = self.pooling(textual_embedding, word_mask.to(dtype=torch.float32))

        # [B, HxW, C_inner] x [B, C_inner, 1] -> [B, HxW, 1]
        pairwise_weight = torch.bmm(
            visual_embedding.permute(1, 0, 2),
            word_attn_embedding.permute(1, 2, 0),
        ).sigmoid()

        # [B, HxW, 1] -> [B, 1, H, W]
        pairwise_spatial_map = pairwise_weight.permute(0, 2, 1).reshape(batch, 1, height, width)

        if self.CNN:
            return pairwise_spatial_map * img_feat, pairwise_spatial_map
        # Token layout needs the weights as [HxW, B, 1] to broadcast over C.
        return pairwise_weight.permute(1, 0, 2) * img_feat, pairwise_spatial_map
