"""
Binary causal mask utilities for the LazyAttention model
"""
import torch


def create_binary_causal_mask(
    batch_size: int,
    seq_length: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    sliding_window: int = None
) -> torch.Tensor:
    """
    Create a binary causal mask where 1 means "can attend" and 0 means "cannot attend"

    Args:
        batch_size: Batch size
        seq_length: Sequence length
        device: Device to place the mask on
        dtype: Data type of the mask
        sliding_window: If specified (and > 0), position i may only attend to
            positions j with i - j < sliding_window. None or a non-positive
            value disables the window and yields a plain causal mask.

    Returns:
        Binary mask of shape (batch_size, 1, seq_length, seq_length)
        where mask[b, 0, i, j] = 1 if position i can attend to position j.
        The batch dimension is produced with ``expand``, so the result is a
        broadcast view rather than batch_size independent copies.
    """
    # Causal part: lower triangular matrix, mask[i, j] = 1 iff i >= j
    # (each position sees itself and all previous positions).
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device, dtype=dtype))

    # Sliding window: additionally require i - j < sliding_window.
    # That is equivalent to j - i >= 1 - sliding_window, i.e. keeping only the
    # entries on or above the (1 - sliding_window)-th diagonal, which a single
    # vectorized triu call does without any per-element Python loop.
    if sliding_window is not None and sliding_window > 0:
        mask = torch.triu(mask, diagonal=1 - sliding_window)

    # Expand to batch dimension:
    # (seq_length, seq_length) -> (batch_size, 1, seq_length, seq_length)
    return mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)


def create_binary_attention_mask(
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_length: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    sliding_window: int = None
) -> torch.Tensor:
    """
    Create a combined binary mask from causal mask and optional padding mask

    Args:
        attention_mask: Optional 2D padding mask of shape (batch_size, seq_length)
                       where 1 means "valid token" and 0 means "padding".
                       Pass None to get the causal mask only.
        batch_size: Batch size
        seq_length: Sequence length
        device: Device to place the mask on
        dtype: Data type of the mask
        sliding_window: If specified, only allow attention within this window

    Returns:
        Binary mask of shape (batch_size, 1, seq_length, seq_length) with the
        requested dtype, located on the requested device.
    """
    # Create causal mask
    causal_mask = create_binary_causal_mask(
        batch_size, seq_length, device, dtype, sliding_window
    )

    # No padding information: the causal mask is the final mask.
    if attention_mask is None:
        return causal_mask

    # Bring the padding mask onto the same device/dtype as the causal mask
    # before multiplying. Without this, type promotion could return a dtype
    # other than the requested one (e.g. a float32 padding mask with
    # dtype=float16), and a padding mask on another device would raise.
    padding_mask = attention_mask.to(device=device, dtype=dtype)

    # (batch_size, seq_length) -> (batch_size, 1, 1, seq_length) so it
    # broadcasts over the query dimension and masks padded key positions.
    padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)

    # Combine: a position can attend only if both causal and padding allow it
    return causal_mask * padding_mask
