"""
Cross-Attention Module
Used for cross-modal fusion in brain signal generation tasks
"""
import torch
import torch.nn as nn


class MultiModalCrossAttention(nn.Module):
    """
    Unidirectional Cross-Attention Module

    Used for conditional generation tasks: Q comes from the modality to be generated,
    K/V comes from the conditioning modality

    In Image-to-BrainSignal task:
    - Q: Brain signal patches (to be generated based on image condition)
    - K, V: Image embedding (provides conditioning information)

    This is a thin wrapper around ``nn.MultiheadAttention`` (batch-first layout).
    It applies no normalization and no residual connection; both are left to
    the caller.

    Args:
        dim (int): Hidden dimension (embedding size of both x and context)
        num_heads (int): Number of attention heads
        dropout (float): Dropout probability applied to the attention weights
        qkv_bias (bool): Whether to use bias in the attention projections.
            It is forwarded as ``bias`` to ``nn.MultiheadAttention``, so it
            controls the bias of the Q/K/V input projection AND of the output
            projection.

    Usage:
        cross_attn = MultiModalCrossAttention(dim=768, num_heads=12)
        x = torch.randn(32, 3400, 768)       # Brain signal patches
        context = torch.randn(32, 257, 768)  # Image embedding
        output = cross_attn(x, context)      # [32, 3400, 768]
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout: float = 0.1,
        qkv_bias: bool = True,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        # Use PyTorch built-in MultiheadAttention
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            bias=qkv_bias,  # Applies to both the in-projection and out-projection
            batch_first=True,  # Input format is [B, N, D]
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Unidirectional Cross-Attention forward pass

        Args:
            x: [B, N, D] - Query source (e.g., brain signal patches, LayerNorm already applied externally)
            context: [B, M, D] - Key/Value source (e.g., image embedding)

        Returns:
            output: [B, N, D] - Same shape as input x
        """
        # Cross-attention: Q from x, K/V from context
        # Note: no LayerNorm here; x is expected to be normalized by the caller
        # (DiTBlockXAttention, per the original design note).
        #
        # need_weights=False: the attention map is never used by this module,
        # so skip building it. This lets PyTorch dispatch to the optimized
        # scaled_dot_product_attention implementation instead of materializing
        # the head-averaged attention-weight tensor only to discard it.
        attn_output, _ = self.attn(
            query=x,
            key=context,
            value=context,
            need_weights=False,
        )

        return attn_output


class BidirectionalCrossAttention(nn.Module):
    """
    Bidirectional Cross-Attention Module (optional, for scenarios requiring bidirectional information exchange)

    Simultaneously computes:
    1. x attending to context (x queries context)
    2. context attending to x (context queries x)

    The two results are concatenated along the feature dimension and projected
    back to ``dim`` with a linear layer. Unlike MultiModalCrossAttention, this
    module applies its own LayerNorm to both inputs. No residual connection is
    added.

    Note: Requires x and context to have the same sequence length for concatenation

    Args:
        dim (int): Hidden dimension
        num_heads (int): Number of attention heads
        dropout (float): Dropout probability applied to the attention weights
    """

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        # x -> context direction attention
        self.attn_x2c = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # context -> x direction attention
        self.attn_c2x = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Layer norms (one per input stream)
        self.norm_x = nn.LayerNorm(dim)
        self.norm_c = nn.LayerNorm(dim)

        # Output projection (fuse results from both directions)
        self.linear_out = nn.Linear(2 * dim, dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Bidirectional Cross-Attention forward pass

        Args:
            x: [B, N, D]
            context: [B, N, D] - Note: Sequence length must be the same as x

        Returns:
            output: [B, N, D]

        Raises:
            ValueError: If x and context have different sequence lengths.
        """
        # Fail fast with a clear message: the two attention outputs have the
        # sequence lengths of their respective queries (x and context), so
        # concatenating them on the feature axis requires equal lengths.
        # Checking here avoids running both attention passes before the
        # concatenation fails.
        if x.size(-2) != context.size(-2):
            raise ValueError(
                "BidirectionalCrossAttention requires x and context to have the "
                f"same sequence length, got {x.size(-2)} and {context.size(-2)}"
            )

        # Normalize
        x_norm = self.norm_x(x)
        c_norm = self.norm_c(context)

        # need_weights=False in both calls: the attention maps are discarded,
        # so let PyTorch use the optimized scaled_dot_product_attention path
        # instead of materializing the averaged attention weights.

        # x attending to context (Q from x, K/V from context)
        out_x2c, _ = self.attn_x2c(
            query=x_norm, key=c_norm, value=c_norm, need_weights=False
        )

        # context attending to x (Q from context, K/V from x)
        out_c2x, _ = self.attn_c2x(
            query=c_norm, key=x_norm, value=x_norm, need_weights=False
        )

        # Concatenate on the feature axis ([..., 2 * D]) and project back to D
        output = torch.cat([out_x2c, out_c2x], dim=-1)
        output = self.linear_out(output)

        return output

