import torch
import torch.nn as nn


class CrossModalAttention(nn.Module):
    """Bidirectional cross-attention between two modalities.

    One set of query/key/value projections is shared by both inputs. The
    queries of each modality attend over the keys of the other modality, and
    the resulting weights mix the other modality's projected values.

    Args:
        input_dim: Size of the last (feature) dimension of both inputs; every
            projection maps ``input_dim -> input_dim``.
        scale: If True, multiply attention logits by ``input_dim ** -0.5``
            (standard scaled dot-product attention). Defaults to False, which
            keeps the original unscaled logits.
    """

    def __init__(self, input_dim: int, scale: bool = False):
        super().__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)
        # Multiplier applied to the attention logits; 1.0 leaves them unscaled.
        # Large unscaled logits (e.g. input_dim=4096) tend to saturate softmax.
        self.logit_scale = input_dim ** -0.5 if scale else 1.0

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        """Let each modality attend to the other.

        Args:
            x1: Tensor of shape ``(..., L1, input_dim)``. The original author
                noted ``[B, 1, 4096]`` for the intended use.
            x2: Tensor of shape ``(..., L2, input_dim)``. Leading dims must
                broadcast with those of ``x1``.

        Returns:
            A tuple ``(x1_interacted, x2_interacted)`` with shapes
            ``(..., L1, input_dim)`` and ``(..., L2, input_dim)``:
            ``x1`` attending over ``x2`` and ``x2`` attending over ``x1``.

        Note:
            The softmax normalizes over the key positions of the other
            modality. When that sequence length is 1, every weight is 1 and
            the output is just the other modality's projected values.
        """
        query1, key1, value1 = self.query(x1), self.key(x1), self.value(x1)
        query2, key2, value2 = self.query(x2), self.key(x2), self.value(x2)

        # attn12[..., i, j]: weight of x2 position j for x1 position i (and vice versa).
        attn12 = self.softmax(
            torch.matmul(query1, key2.transpose(-1, -2)) * self.logit_scale
        )
        attn21 = self.softmax(
            torch.matmul(query2, key1.transpose(-1, -2)) * self.logit_scale
        )

        # Cross-modal interaction: one attention-weighted sum per direction.
        # (The previous code added the same term twice, doubling the output.)
        x1_interacted = torch.matmul(attn12, value2)
        x2_interacted = torch.matmul(attn21, value1)

        return x1_interacted, x2_interacted
   