"""Bottleneck mask utilities for latent-based attention patterns."""

import torch
from torch.nn.attention.flex_attention import (
    BlockMask,
    _mask_mod_signature,
    create_block_mask,
    or_masks,
)


def latent_cross_attention_mask_factory(
    context_len: int,
    latent_len: int
) -> _mask_mod_signature:
    """Latents can only attend to context tokens.
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        
    Returns:
        Mask function for latent-to-context cross attention
    """
    def latent_cross_attention_mask(b, h, q_idx, kv_idx):
        # Define regions
        latent_start = context_len
        latent_end = context_len + latent_len
        
        # Query must be in latent region
        q_in_latent = (q_idx >= latent_start) & (q_idx < latent_end)
        
        # Key/value must be in context region
        kv_in_context = kv_idx < context_len
        
        return q_in_latent & kv_in_context
    
    latent_cross_attention_mask.__name__ = f"latent_cross_{context_len}_{latent_len}"
    return latent_cross_attention_mask


def latent_self_attention_mask_factory(
    context_len: int,
    latent_len: int,
    causal: bool = False
) -> _mask_mod_signature:
    """Latents attend to each other (optionally causal).
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        causal: Whether latent self-attention should be causal
        
    Returns:
        Mask function for latent self-attention
    """
    def latent_self_attention_mask(b, h, q_idx, kv_idx):
        # Define latent region
        latent_start = context_len
        latent_end = context_len + latent_len
        
        # Both query and key must be in latent region
        q_in_latent = (q_idx >= latent_start) & (q_idx < latent_end)
        kv_in_latent = (kv_idx >= latent_start) & (kv_idx < latent_end)
        
        if causal:
            # Causal mask within latents
            return q_in_latent & kv_in_latent & (q_idx >= kv_idx)
        else:
            # Full attention within latents
            return q_in_latent & kv_in_latent
    
    mask_type = "causal" if causal else "full"
    latent_self_attention_mask.__name__ = f"latent_self_{mask_type}_{context_len}_{latent_len}"
    return latent_self_attention_mask


def buffer_to_latent_mask_factory(
    context_len: int,
    latent_len: int,
    buffer_len: int
) -> _mask_mod_signature:
    """Buffer tokens attend to latent tokens.
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        buffer_len: Number of buffer tokens
        
    Returns:
        Mask function for buffer-to-latent attention
    """
    def buffer_to_latent_mask(b, h, q_idx, kv_idx):
        # Define regions
        latent_start = context_len
        latent_end = context_len + latent_len
        buffer_start = context_len + latent_len
        buffer_end = context_len + latent_len + buffer_len
        
        # Query in buffer, key in latent
        q_in_buffer = (q_idx >= buffer_start) & (q_idx < buffer_end)
        kv_in_latent = (kv_idx >= latent_start) & (kv_idx < latent_end)
        
        return q_in_buffer & kv_in_latent
    
    buffer_to_latent_mask.__name__ = f"buffer_to_latent_{context_len}_{latent_len}_{buffer_len}"
    return buffer_to_latent_mask


def buffer_causal_mask_factory(
    context_len: int,
    latent_len: int,
    buffer_len: int
) -> _mask_mod_signature:
    """Causal self-attention within buffer region.
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        buffer_len: Number of buffer tokens
        
    Returns:
        Mask function for causal buffer self-attention
    """
    def buffer_causal_mask(b, h, q_idx, kv_idx):
        # Define buffer region
        buffer_start = context_len + latent_len
        buffer_end = context_len + latent_len + buffer_len
        
        # Both in buffer region
        q_in_buffer = (q_idx >= buffer_start) & (q_idx < buffer_end)
        kv_in_buffer = (kv_idx >= buffer_start) & (kv_idx < buffer_end)
        
        # Causal within buffer
        is_causal = q_idx >= kv_idx
        
        return q_in_buffer & kv_in_buffer & is_causal
    
    buffer_causal_mask.__name__ = f"buffer_causal_{context_len}_{latent_len}_{buffer_len}"
    return buffer_causal_mask


def target_to_latent_mask_factory(
    context_len: int,
    latent_len: int,
    buffer_len: int
) -> _mask_mod_signature:
    """All target tokens attend to all latent tokens.
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        buffer_len: Number of buffer tokens
        
    Returns:
        Mask function for target-to-latent attention
    """
    def target_to_latent_mask(b, h, q_idx, kv_idx):
        # Define regions
        latent_start = context_len
        latent_end = context_len + latent_len
        target_start = context_len + latent_len + buffer_len
        
        # Query in target, key in latent
        q_in_target = q_idx >= target_start
        kv_in_latent = (kv_idx >= latent_start) & (kv_idx < latent_end)
        
        return q_in_target & kv_in_latent
    
    target_to_latent_mask.__name__ = f"target_to_latent_{context_len}_{latent_len}_{buffer_len}"
    return target_to_latent_mask


def chunked_target_buffer_mask_factory(
    context_len: int,
    latent_len: int,
    buffer_len: int,
    attending_chunks: int = 4
) -> _mask_mod_signature:
    """First N chunks of targets attend causally to buffer (reusing pattern from masks.py).
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        buffer_len: Number of buffer tokens
        attending_chunks: Number of target chunks that attend to buffer
        
    Returns:
        Mask function for chunked target-buffer attention pattern
    """
    def chunked_target_buffer_mask(b, h, q_idx, kv_idx):
        buffer_start = context_len + latent_len
        buffer_end = context_len + latent_len + buffer_len
        target_start = context_len + latent_len + buffer_len
        
        q_in_target = q_idx >= target_start
        kv_in_buffer = (kv_idx >= buffer_start) & (kv_idx < buffer_end)
        
        # Basic condition
        base_condition = q_in_target & kv_in_buffer
        
        target_offset = q_idx - target_start
        buffer_offset = kv_idx - buffer_start
        
        # First attending_chunks * buffer_len positions attend
        in_attending_region = target_offset < (attending_chunks * buffer_len)
        
        # Causal within chunk
        chunk_position = target_offset % buffer_len
        causal = buffer_offset <= chunk_position
        
        return base_condition & in_attending_region & causal
    
    chunked_target_buffer_mask.__name__ = f"chunked_target_buffer_{attending_chunks}"
    return chunked_target_buffer_mask


def diagonal_mask_except_context(context_len: int) -> _mask_mod_signature:
    """Diagonal self-attention mask for all positions except context.
    
    Args:
        context_len: Number of context tokens to exclude from diagonal
        
    Returns:
        Mask function that allows diagonal attention except in context region
    """
    def diagonal_mask(b, h, q_idx, kv_idx):
        # Only allow diagonal if both q and kv are NOT in context
        not_in_context = (q_idx >= context_len) & (kv_idx >= context_len)
        is_diagonal = q_idx == kv_idx
        return not_in_context & is_diagonal
    
    diagonal_mask.__name__ = f"diagonal_except_context_{context_len}"
    return diagonal_mask


def generate_bottleneck_mask_mod(
    context_len: int,
    latent_len: int, 
    buffer_len: int,
    attending_chunks: int = 4,
    latent_self_attention_type: str = "full"  # "full", "causal", or "none"
) -> _mask_mod_signature:
    """Generate comprehensive bottleneck mask combining all patterns.
    
    The attention pattern:
    - Context: No attention (processes through embedder only)
    - Latent: Cross-attend to context + self-attention
    - Buffer: Cross-attend to latent + causal self-attention
    - Target: Cross-attend to latent + chunked buffer attention + diagonal
    
    Args:
        context_len: Number of context tokens
        latent_len: Number of latent tokens
        buffer_len: Number of buffer tokens
        attending_chunks: Number of target chunks that attend to buffer
        latent_self_attention_type: Type of self-attention for latents
        
    Returns:
        Combined mask function for bottleneck attention pattern
    """
    # Create component masks
    masks = []
    
    # Latent patterns
    masks.append(latent_cross_attention_mask_factory(context_len, latent_len))
    
    if latent_self_attention_type != "none":
        causal = latent_self_attention_type == "causal"
        masks.append(latent_self_attention_mask_factory(context_len, latent_len, causal=causal))
    
    # Buffer patterns
    masks.append(buffer_to_latent_mask_factory(context_len, latent_len, buffer_len))
    masks.append(buffer_causal_mask_factory(context_len, latent_len, buffer_len))
    
    # Target patterns
    masks.append(target_to_latent_mask_factory(context_len, latent_len, buffer_len))
    masks.append(chunked_target_buffer_mask_factory(context_len, latent_len, buffer_len, attending_chunks))
    
    # Diagonal attention everywhere except context
    masks.append(diagonal_mask_except_context(context_len))
    
    # Combine all masks with OR
    final_mask = or_masks(*masks)
    
    final_mask.__name__ = (
        f"bottleneck_{context_len}_{latent_len}_{buffer_len}_{attending_chunks}_{latent_self_attention_type}"
    )
    return final_mask


def create_bottleneck_block_mask(
    current_total_q_len: int,
    current_total_kv_len: int,
    current_context_len: int,
    current_latent_len: int,
    current_buffer_len: int,
    attending_chunks: int = 4,
    latent_self_attention_type: str = "full",
    q_block_size: int = 128,
    kv_block_size: int = 128,
    device: str = "cuda",
) -> BlockMask:
    """Compile the bottleneck mask_mod into a FlexAttention ``BlockMask``.

    Args:
        current_total_q_len: Total query sequence length.
        current_total_kv_len: Total key/value sequence length (equal to the query
            length for self-attention).
        current_context_len: Number of context tokens.
        current_latent_len: Number of latent tokens.
        current_buffer_len: Number of buffer tokens.
        attending_chunks: Number of target chunks that attend to the buffer.
        latent_self_attention_type: Latent self-attention kind ("full", "causal", "none").
        q_block_size: Block size along the query axis.
        kv_block_size: Block size along the key/value axis.
        device: Device on which the mask tensors are created.

    Returns:
        A ``BlockMask`` implementing the bottleneck attention pattern.
    """
    mask_mod = generate_bottleneck_mask_mod(
        context_len=current_context_len,
        latent_len=current_latent_len,
        buffer_len=current_buffer_len,
        attending_chunks=attending_chunks,
        latent_self_attention_type=latent_self_attention_type,
    )
    block_size = (q_block_size, kv_block_size)

    # B=None / H=None: none of the component masks read `b` or `h`, so a single
    # mask is broadcast over batch and heads.
    return create_block_mask(
        mask_mod,
        B=None,
        H=None,
        Q_LEN=current_total_q_len,
        KV_LEN=current_total_kv_len,
        device=device,
        BLOCK_SIZE=block_size,
    )


# Convenience function for visualization
def visualize_bottleneck_mask(
    context_len: int = 8,
    latent_len: int = 4, 
    buffer_len: int = 4,
    target_len: int = 16,
    attending_chunks: int = 2,
    latent_self_attention_type: str = "full"
):
    """Print a visual representation of the bottleneck mask pattern.

    Builds the combined mask for a ``[context | latent | buffer | target]`` layout
    and prints a summary of the regions. It then prints one row per query
    position, with ``✓`` in each column whose key that query may attend to.

    Args:
        context_len: Number of context tokens.
        latent_len: Number of latent tokens.
        buffer_len: Number of buffer tokens (also the target chunk size; may be 0).
        target_len: Number of target tokens.
        attending_chunks: Number of target chunks that attend to the buffer.
        latent_self_attention_type: "full", "causal", or "none".
    """
    total_len = context_len + latent_len + buffer_len + target_len

    mask_mod = generate_bottleneck_mask_mod(
        context_len=context_len,
        latent_len=latent_len,
        buffer_len=buffer_len,
        attending_chunks=attending_chunks,
        latent_self_attention_type=latent_self_attention_type
    )

    print("\nBottleneck Mask Pattern:")
    print(f"  Context: 0-{context_len-1} (no attention)")
    print(f"  Latent: {context_len}-{context_len+latent_len-1} (attend to context + {latent_self_attention_type} self)")
    print(f"  Buffer: {context_len+latent_len}-{context_len+latent_len+buffer_len-1} (attend to latent + causal self)")
    print(f"  Target: {context_len+latent_len+buffer_len}-{total_len-1} (attend to latent + first {attending_chunks} chunks to buffer)")

    print("\nAttention pattern (✓ = can attend):")
    print("     " + "".join(f"{i:2d} " for i in range(total_len)))

    # Every component mask is elementwise. Broadcasting a column of query indices
    # against a row of key indices therefore evaluates the whole grid in one
    # call, instead of total_len**2 separate mask_mod calls.
    positions = torch.arange(total_len)
    zero = torch.tensor(0)
    grid = mask_mod(zero, zero, positions.unsqueeze(1), positions.unsqueeze(0))
    grid = torch.broadcast_to(grid, (total_len, total_len)).tolist()

    target_start = context_len + latent_len + buffer_len
    for q, allowed_row in enumerate(grid):
        row = f"Q{q:2d}: " + "".join(" ✓ " if ok else " . " for ok in allowed_row)

        # Add section label
        if q < context_len:
            row += " (context - no attention)"
        elif q < context_len + latent_len:
            row += " (latent)"
        elif q < target_start:
            row += " (buffer)"
        else:
            # Chunks are buffer_len wide; an empty buffer means no chunking,
            # and guarding here avoids a ZeroDivisionError.
            chunk = (q - target_start) // buffer_len if buffer_len > 0 else 0
            row += f" (target chunk {chunk})"

        print(row)


if __name__ == "__main__":
    # Running the module directly renders the default toy layout as a smoke check.
    visualize_bottleneck_mask()