import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Cross-attention layer followed by a residual feed-forward network.

    ``query`` attends over ``key_value`` with multi-head attention. The
    attention output is added back to ``query`` and layer-normalized, then a
    two-layer ReLU MLP is applied with a residual connection.

    Args:
        dim: Embedding size of both ``query`` and ``key_value``; must be
            divisible by ``heads`` (enforced by ``nn.MultiheadAttention``).
        heads: Number of attention heads.
        ff_mult: Hidden width of the feed-forward network as a multiple of
            ``dim``. The default ``1`` reproduces the original
            ``Linear(dim, dim) -> ReLU -> Linear(dim, dim)`` layout.
        dropout: Dropout probability applied to the attention weights inside
            ``nn.MultiheadAttention``. The default ``0.0`` disables it.
    """

    def __init__(self, dim: int, heads: int = 8, ff_mult: int = 1, dropout: float = 0.0):
        super().__init__()
        if ff_mult < 1:
            raise ValueError(f"ff_mult must be >= 1, got {ff_mult}")
        hidden = dim * ff_mult
        # batch_first=True: all tensors are (batch, seq, dim).
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        # Keep the Sequential layout (indices 0 and 2 are the Linear layers)
        # so state_dict keys match checkpoints saved by earlier versions.
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, query, key_value, mask=None):
        """Attend from ``query`` to ``key_value`` and apply the feed-forward.

        Args:
            query: Tensor of shape ``(batch, q_len, dim)``.
            key_value: Tensor of shape ``(batch, kv_len, dim)``; used as both
                keys and values.
            mask: Optional ``key_padding_mask`` of shape ``(batch, kv_len)``.
                For a boolean mask, ``True`` marks key positions to IGNORE
                (e.g. padding), per ``nn.MultiheadAttention`` semantics.

        Returns:
            ``(out, attn_weights)``. ``out`` has the same shape as ``query``.
            ``attn_weights`` is what ``nn.MultiheadAttention`` returns with its
            default settings: weights averaged over heads, of shape
            ``(batch, q_len, kv_len)``.
        """
        attn_out, attn_weights = self.attn(query, key_value, key_value, key_padding_mask=mask)
        out = self.norm(query + attn_out)
        # NOTE(review): unlike a standard post-LN transformer layer, the FF
        # residual is not followed by a second LayerNorm. This is kept as-is to
        # preserve outputs and checkpoint compatibility; confirm it is intended.
        out = out + self.ff(out)
        return out, attn_weights