import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class TextAttentionNeck(nn.Module):
    """Fuse two text-token sequences and decode a single-channel spatial mask.

    Both inputs go through the same linear projection. The projected
    ``x_text`` tokens then attend to the projected ``z_text`` tokens with
    multi-head cross-attention, followed by a residual connection and
    ``LayerNorm``. The fused ``x_text`` tokens are laid out as a square 2D
    grid and passed through a small convolutional head that ends in a
    sigmoid, so every output value lies in ``[0, 1]``.
    """

    def __init__(self, text_dim: int, embed_dim: int, num_heads: int, output_size: tuple):
        """
        Args:
            text_dim: Size of the last dimension of the incoming token features.
            embed_dim: Width of the shared projection and of the attention
                layer. The mask head narrows it to ``embed_dim // 4`` and
                ``embed_dim // 8`` channels, so it should be at least 8.
            num_heads: Number of heads passed to ``nn.MultiheadAttention``.
            output_size: Target spatial size handed to ``nn.AdaptiveAvgPool2d``;
                it fixes the spatial size of the returned mask.
        """
        super().__init__()
        self.output_size = output_size

        # A single projection is shared by both token sequences.
        self.text_proj = nn.Linear(text_dim, embed_dim)

        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

        # Conv stack: reduce channels, downsample by 2, pool to the requested
        # size, then collapse to one channel squashed by a sigmoid.
        self.mask_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim // 4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim // 4, embed_dim // 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size),
            nn.Conv2d(embed_dim // 8, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, z_text: torch.Tensor, x_text: torch.Tensor) -> torch.Tensor:
        """Compute the mask for ``x_text`` conditioned on ``z_text``.

        Args:
            z_text: Tokens used as attention keys and values,
                shape ``[B, Nzt, text_dim]``.
            x_text: Tokens used as attention queries, shape
                ``[B, Nxt, text_dim]``. ``Nxt`` must be a perfect square
                because the tokens are reshaped into an ``h x h`` grid.

        Returns:
            Tensor of shape ``[B, 1, *output_size]`` with values in ``[0, 1]``.

        Raises:
            RuntimeError: If ``Nxt`` is not a perfect square.
        """
        z_text_proj = self.text_proj(z_text)  # [B, Nzt, embed_dim]
        x_text_proj = self.text_proj(x_text)  # [B, Nxt, embed_dim]

        text_enhancement, _ = self.cross_attention(
            query=x_text_proj, key=z_text_proj, value=z_text_proj
        )
        # Residual connection followed by normalisation.
        x_text_fused = self.norm(x_text_proj + text_enhancement)  # [B, Nxt, embed_dim]

        bs, num_tokens, channels = x_text_fused.shape
        # isqrt is exact for integers, unlike int(math.sqrt(...)).
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            # RuntimeError matches what the unchecked reshape would raise, but
            # with a message that points at the actual cause.
            raise RuntimeError(
                f"x_text must contain a perfect-square number of tokens to form "
                f"a square grid, got {num_tokens}"
            )

        # Channels-first layout for the conv head; reshape (rather than view)
        # does not depend on the memory layout produced upstream.
        x_text_grid = x_text_fused.transpose(1, 2).reshape(bs, channels, side, side)

        text_attention_mask = self.mask_head(x_text_grid)  # [B, 1, *output_size]

        return text_attention_mask