"""Block mask utilities for flex attention patterns."""

import math
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import (
    BlockMask,
    _mask_mod_signature,
    _score_mod_signature,
    create_block_mask,
    or_masks,
)

# Pre-compile create_block_mask for better performance (following nano-vllm pattern)
if torch.cuda.is_available() or hasattr(torch.version, 'hip'):
    create_block_mask = torch.compile(create_block_mask)

try:
    from torch.nn.attention.flex_attention import (
        TransformGetItemToIndex,
        _ModificationType,
        _vmap_for_bhqkv,
    )
except ImportError:
    _ModificationType = None
    _vmap_for_bhqkv = None
    TransformGetItemToIndex = None


def create_dense_kv_block_info(
    current_query_len: int,
    current_key_len: int,
    Q_BLOCK_SIZE: int,
    KV_BLOCK_SIZE: int,
    device: str,
):
    """Generate vectorized KV block info for dense N x M attention pattern.

    Creates block indices for a fully dense attention pattern where every query
    block can attend to every key block. This is used for self-attention and
    other dense attention patterns.

    Zero-length sequences are supported: they produce tensors with a zero-sized
    block dimension instead of raising.

    Args:
        current_query_len: Length of the query sequence
        current_key_len: Length of the key/value sequence
        Q_BLOCK_SIZE: Size of query blocks for attention chunking
        KV_BLOCK_SIZE: Size of key/value blocks for attention chunking
        device: Device to place tensors on (e.g., 'cuda', 'cpu')

    Returns:
        Tuple containing:
            - kv_num_blocks: int64 tensor of shape (1, 1, num_q_blocks) whose
              entries all equal num_k_blocks (how many key blocks each query
              block attends to)
            - kv_indices: int64 tensor of shape (1, 1, num_q_blocks, num_k_blocks)
              where every row is [0, 1, ..., num_k_blocks - 1]. It is a
              broadcast (expanded) view, not a contiguous copy.
    """
    # Ceil-divide so a trailing partial block is still counted.
    num_q_blocks = (current_query_len + Q_BLOCK_SIZE - 1) // Q_BLOCK_SIZE
    num_k_blocks = (current_key_len + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE

    # Each query block attends to all key blocks
    kv_num_blocks = torch.full(
        (1, 1, num_q_blocks), num_k_blocks, dtype=torch.long, device=device
    )

    # Broadcast the row [0, 1, ..., num_k_blocks-1] to every query block.
    # Expanding straight to the 4-D BlockMask shape avoids `view(1, -1)`, which
    # raises for a zero-element tensor (e.g. an empty key sequence).
    kv_indices = torch.arange(num_k_blocks, device=device).expand(
        1, 1, num_q_blocks, num_k_blocks
    )

    return kv_num_blocks, kv_indices


def chunked_target_buffer_mask_factory(
    context_len: int,
    buffer_len: int,
    attending_chunks: int = 4
) -> _mask_mod_signature:
    """First N chunks of targets attend causally to buffer.
    
    Creates a mask where targets are divided into chunks of buffer_len size.
    The first attending_chunks number of chunks can attend to buffer with
    a causal pattern, while remaining chunks have no buffer attention.
    
    Args:
        context_len: Number of context tokens
        buffer_len: Number of buffer tokens (also chunk size)
        attending_chunks: Number of target chunks that attend to buffer
        
    Returns:
        Mask function for chunked target-buffer attention pattern
    """
    def chunked_target_buffer_mask(b, h, q_idx, kv_idx):
        buffer_start = context_len
        target_start = context_len + buffer_len
        
        q_in_target = q_idx >= target_start
        kv_in_buffer = (kv_idx >= buffer_start) & (kv_idx < buffer_start + buffer_len)
        
        # Basic condition: query in target and key in buffer
        base_condition = q_in_target & kv_in_buffer
        
        target_offset = q_idx - target_start
        buffer_offset = kv_idx - buffer_start
        
        # First attending_chunks * buffer_len positions attend
        in_attending_region = target_offset < (attending_chunks * buffer_len)
        
        # Causal within chunk
        chunk_position = target_offset % buffer_len
        causal = buffer_offset <= chunk_position
        
        return base_condition & in_attending_region & causal
    
    chunked_target_buffer_mask.__name__ = f"chunked_{context_len}_{buffer_len}_{attending_chunks}"
    return chunked_target_buffer_mask


def diag_mask_runtime(b, h, q_idx, kv_idx):
    """Mask that keeps only the diagonal, so each query sees only its own key.

    Args:
        b: Batch index; ignored.
        h: Head index; ignored.
        q_idx: Query position index.
        kv_idx: Key/Value position index.

    Returns:
        Result of ``q_idx == kv_idx``: True exactly where the query and key
        positions coincide.
    """
    on_diagonal = q_idx == kv_idx
    return on_diagonal


# Short, stable name for the mask function (flex-attention style identifiers).
setattr(diag_mask_runtime, "__name__", "diag_mask_rt")


def generate_training_mask_mod_runtime(
    current_context_len: int,
    current_buffer_len: int,
    attending_chunks: int = 4
) -> _mask_mod_signature:
    """Build the composite training mask_mod as an OR of four component masks.

    Components (a query/key pair is allowed if ANY of them allows it):
    1. Prefix: every query may attend to every context key
       (``kv_idx < current_context_len``).
    2. Localized causal: within the combined context+buffer prefix, a query
       may attend to any key at or before its own position.
    3. Diagonal: every query may attend to its own position.
    4. Chunked target-buffer: built by ``chunked_target_buffer_mask_factory``
       with the same lengths and ``attending_chunks``.

    Args:
        current_context_len: Length of the context section
        current_buffer_len: Length of the buffer section
        attending_chunks: Number of target chunks that attend to buffer

    Returns:
        The ``or_masks`` combination of the components, renamed to
        ``training_mask_<context>_<buffer>_<attending_chunks>``.
    """
    region_end = current_context_len + current_buffer_len

    def prefix_mask(b, h, q_idx, kv_idx):
        # Any query may look at any context key.
        return kv_idx < current_context_len

    def localized_causal_ctx_buf(b, h, q_idx, kv_idx):
        # Causal attention restricted to the context+buffer prefix.
        both_in_region = (q_idx < region_end) & (kv_idx < region_end)
        return both_in_region & (q_idx >= kv_idx)

    components = (
        prefix_mask,
        localized_causal_ctx_buf,
        diag_mask_runtime,
        chunked_target_buffer_mask_factory(
            current_context_len, current_buffer_len, attending_chunks
        ),
    )
    combined = or_masks(*components)

    combined.__name__ = (
        f"training_mask_{current_context_len}_{current_buffer_len}_{attending_chunks}"
    )
    return combined


def create_context_self_attention_block_mask(
    current_num_context: int, q_block_size: int, kv_block_size: int, device: str
) -> BlockMask:
    """Build a fully dense BlockMask for self-attention over the context tokens.

    Every context token may attend to every context token, itself included.
    The block layout comes from ``create_dense_kv_block_info``. The accompanying
    mask_mod allows every position, so partial edge blocks are not masked
    further.

    Args:
        current_num_context: Number of context tokens (both query and key length)
        q_block_size: Block size for query chunking
        kv_block_size: Block size for key/value chunking
        device: Device to place mask tensors on

    Returns:
        BlockMask with sequence lengths (current_num_context, current_num_context)
    """
    seq_len = current_num_context

    def _dense_allow_all_mask_mod(b, h, q, kv):
        # Unconditionally allow attention; result has the query index's shape.
        return torch.ones_like(q, dtype=torch.bool)

    _dense_allow_all_mask_mod.__name__ = f"dense_self_attn_ctx{current_num_context}"

    num_blocks, block_indices = create_dense_kv_block_info(
        current_query_len=seq_len,
        current_key_len=seq_len,
        Q_BLOCK_SIZE=q_block_size,
        KV_BLOCK_SIZE=kv_block_size,
        device=device,
    )

    return BlockMask.from_kv_blocks(
        kv_num_blocks=num_blocks,
        kv_indices=block_indices,
        BLOCK_SIZE=(q_block_size, kv_block_size),
        mask_mod=_dense_allow_all_mask_mod,
        seq_lengths=(seq_len, seq_len),
    )


def create_training_block_mask(
    current_total_q_len: int,
    current_total_kv_len: int,
    current_context_section_len: int,
    current_buffer_section_len: int,
    attending_chunks: int = 4,
    q_block_size: int = 128,
    kv_block_size: int = 128,
    device: str = "cuda",
) -> BlockMask:
    """Build the ACEv2 training BlockMask from the composite training mask_mod.

    The mask_mod comes from ``generate_training_mask_mod_runtime``. In it:
    - every token may attend to all context tokens;
    - context and buffer tokens attend causally within the context+buffer prefix;
    - every token may attend to itself;
    - target tokens in the first ``attending_chunks`` chunks attend causally
      to the buffer.

    The block structure is derived with ``create_block_mask``, which is simple
    but slower to build than hand-constructed kv_blocks. Callers may cache the
    resulting masks keyed on sequence lengths.

    Args:
        current_total_q_len: Total length of query sequence (context + buffer + targets)
        current_total_kv_len: Total length of key/value sequence
            (same as query for self-attention)
        current_context_section_len: Length of context section for mask logic
        current_buffer_section_len: Length of buffer section for mask logic
        attending_chunks: Number of target chunks that attend to buffer
        q_block_size: Block size for query chunking
        kv_block_size: Block size for key/value chunking
        device: Device to place mask tensors on

    Returns:
        BlockMask shared across batch and heads (B=None, H=None)
    """
    mask_mod = generate_training_mask_mod_runtime(
        current_context_len=current_context_section_len,
        current_buffer_len=current_buffer_section_len,
        attending_chunks=attending_chunks,
    )
    block_size = (q_block_size, kv_block_size)
    return create_block_mask(
        mask_mod,
        Q_LEN=current_total_q_len,
        KV_LEN=current_total_kv_len,
        B=None,
        H=None,
        BLOCK_SIZE=block_size,
        device=device,
    )


def create_score_mod(
    query: torch.Tensor,
    key: torch.Tensor,
    score_mod: Optional[_score_mod_signature],
    mask_mod: Optional[_mask_mod_signature],
    device: str = "cuda",
    _compile: bool = False,
    scale: Optional[float] = None,
    batch_idx: int = 0,
    head_idx: int = 0,
) -> torch.Tensor:
    """Create attention scores for visualization using score_mod or mask_mod.

    Computes attention scores between query and key tensors, applying the provided
    score_mod or mask_mod function for visualization purposes.

    Args:
        query: Query tensor of shape (seq_len, head_dim)
        key: Key tensor of shape (seq_len, head_dim)
        score_mod: Optional score modification function
        mask_mod: Optional mask modification function
        device: Device for computation
        _compile: Whether to compile the modification function
        scale: Scale factor for attention scores (default: 1/sqrt(head_dim))
        batch_idx: Batch index offset for modification functions
        head_idx: Head index offset for modification functions

    Returns:
        Tensor of attention scores with modifications applied
    """
    B = 1
    H = 1
    M = query.shape[0]
    N = key.shape[0]

    b = torch.arange(0, B, device=device) + batch_idx
    h = torch.arange(0, H, device=device) + head_idx
    m = torch.arange(0, M, device=device)
    n = torch.arange(0, N, device=device)

    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    type = (
        _ModificationType.SCORE_MOD
        if score_mod is not None
        else _ModificationType.MASK_MOD
    )
    if _compile:
        ctx = nullcontext()
    else:
        ctx = TransformGetItemToIndex()

    with ctx:
        mod_fn = score_mod if type == _ModificationType.SCORE_MOD else mask_mod
        prefix = (0,) if type == _ModificationType.SCORE_MOD else ()
        mod = _vmap_for_bhqkv(mod_fn, prefix=prefix)
        scores = query @ key.transpose(-2, -1)
        scores *= scale_factor
        scores = scores.view(1, 1, M, N)
        if type == _ModificationType.SCORE_MOD:
            out = mod(scores, b, h, m, n)
        else:
            out = mod(b, h, m, n)

    return out


def _name_to_title(name: str) -> str:
    """Convert underscore-separated name to title case.

    Args:
        name: String with underscores to convert

    Returns:
        Title case string with spaces instead of underscores
    """
    title = name.replace("_", " ")
    title = " ".join(word.capitalize() for word in title.split())
    return title


def visualize_attention_scores(
    query: Tensor,
    key: Tensor,
    score_mod: Optional[_score_mod_signature] = None,
    mask_mod: Optional[_mask_mod_signature] = None,
    device: str = "cuda",
    name: str = "attention_scores",
    path: Optional[Path] = None,
    batch_idx: int = 0,
    head_idx: int = 0,
    scale: Optional[float] = None,
):
    """Generate and save a visualization of attention scores.

    Creates a heatmap of attention scores between one batch/head slice of
    ``query`` and ``key``, with the score_mod and/or mask_mod applied.
    - If both are given, positions rejected by ``mask_mod`` are set to -inf on
      top of the score_mod output.
    - If only ``mask_mod`` is given, the mask itself is plotted.

    Args:
        query: Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim)
        key: Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim)
        score_mod: Score modification function
        mask_mod: Mask modification function
        device: Device for computation
        name: Base name for file and title
        path: Save path (``str`` or ``Path``). Its suffix is replaced with
            ``.png``. Defaults to ``<name>.png`` in the current directory.
        batch_idx: Batch index to visualize
        head_idx: Head index to visualize
        scale: Scale factor for attention scores

    Returns:
        None (saves plot to file)

    Raises:
        ValueError: If neither ``score_mod`` nor ``mask_mod`` is provided.
    """
    # Explicit raise (not assert) so validation survives `python -O`.
    if score_mod is None and mask_mod is None:
        raise ValueError("Must provide either score_mod or mask_mod")

    query = query[batch_idx, head_idx, :, :]
    key = key[batch_idx, head_idx, :, :]
    scores_viz = create_score_mod(
        query,
        key,
        score_mod=score_mod,
        mask_mod=mask_mod,
        scale=scale,
        device=device,
        batch_idx=batch_idx,
        head_idx=head_idx,
    )
    # If both score_mod and mask_mod are provided, apply both
    if score_mod is not None and mask_mod is not None:
        mask_viz = create_score_mod(
            query,
            key,
            score_mod=None,
            mask_mod=mask_mod,
            scale=scale,
            device=device,
            batch_idx=batch_idx,
            head_idx=head_idx,
        )
        # Apply mask by setting masked positions to -inf
        scores_viz = torch.where(mask_viz == 0, float("-inf"), scores_viz)

    if score_mod is not None and mask_mod is not None:
        color = "plasma"
    elif score_mod is not None:
        color = "viridis"
    else:
        color = "cividis"

    suffix_title = (
        f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""
    )
    title = _name_to_title(name)
    full_title = f"{title}\n{suffix_title}" if suffix_title else title
    file_path = Path(name if path is None else path).with_suffix(".png")

    # Cast to float32 so low-precision dtypes numpy cannot represent (e.g.
    # bfloat16) and boolean masks both reach imshow as a plain float array.
    image = scores_viz.detach().float().cpu()[0, 0, :, :].numpy()
    num_query_tokens, num_kv_tokens = image.shape

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        im = ax.imshow(image, aspect="auto", cmap=color)
        fig.colorbar(im)

        ax.set_title(full_title, fontsize=20)
        ax.set_xlabel("Key Tokens", fontsize=18)
        ax.set_ylabel("Query Tokens", fontsize=18)

        # Move x-axis ticks and labels to the top
        ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)

        # Add tick labels if the number of tokens is manageable
        if num_query_tokens <= 32 and num_kv_tokens <= 32:
            ax.set_xticks(range(num_kv_tokens))
            rotation = 45 if num_kv_tokens > 12 else 0
            ax.set_xticklabels(
                [f"KV{i}" for i in range(num_kv_tokens)], fontsize=16, rotation=rotation
            )
            ax.set_yticks(range(num_query_tokens))
            ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16)
            # Align grid with pixel boundaries
            ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
            ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
            ax.grid(which="minor", color="black", linestyle="-", linewidth=2)

        fig.tight_layout()
        fig.savefig(file_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release this specific figure, even if plotting/saving failed.
        plt.close(fig)
