import math
import torch
import torch.nn.functional as F
from einops import rearrange

# Explicit (non-fused) scaled dot-product attention that additionally applies top-k filtering to image-token scores during decoding:
def scaled_dot_product_attention_image_top_k(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    """Scaled dot-product attention with top-k filtering of image-token scores.

    The attention scores are computed explicitly (``query @ key^T * scale``
    plus an additive bias built from ``is_causal`` / ``attn_mask``). When the
    query length is 1 and image regions are configured, only the
    ``image_top_k`` largest scores inside each image region are kept; the
    remaining image-region scores are set to ``-inf`` before the softmax.

    Configuration is read from the ``kwargs`` attribute of this function
    object, which must be a mapping set by the caller before invocation, with
    the keys:

    * ``'image_top_k_attention'``: number of scores to keep per image region.
    * ``'image_regions'``: index used to select image-token key positions on
      the last axis of the score tensor, or ``None`` to disable filtering.
      NOTE(review): the rearrange pattern below expects the indexed result to
      be 5-D ``(b, h, 1, 1, n * l)``, which suggests a 2-D index of shape
      ``(1, n * l)`` -- confirm against the caller.
    * ``'num_images'``: number of equally sized regions the selected
      positions are split into.

    Args:
        query: Query tensor; its second-to-last axis is the target length L
            and its last axis the embedding size.
        key: Key tensor; its second-to-last axis is the source length S.
        value: Value tensor multiplied by the attention weights.
        attn_mask: Optional mask broadcastable to the score tensor. A boolean
            mask keeps positions that are ``True``; any other dtype is added
            to the scores. May carry leading batch/head dimensions.
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is always applied in training mode (``train=True``).
        is_causal: If true, apply a lower-triangular causal mask.
            ``attn_mask`` must then be ``None``.
        scale: Score scaling factor; defaults to ``1 / sqrt(query.size(-1))``.

    Returns:
        The attention output ``softmax(scores) @ value``.

    Raises:
        ValueError: If an image region is shorter than ``image_top_k``.
        AssertionError: If ``is_causal`` is set together with ``attn_mask``.
    """
    # Settings are attached to the function object itself by the caller.
    kwargs = scaled_dot_product_attention_image_top_k.kwargs
    image_top_k = kwargs['image_top_k_attention']
    image_regions = kwargs['image_regions']
    num_images = kwargs['num_images']

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Out-of-place on purpose: the (L, S) bias must be able to
            # broadcast up to a mask with leading batch/head dimensions,
            # which the in-place masked_fill_ cannot do.
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            # Out-of-place for the same broadcasting reason as above.
            attn_bias = attn_bias + attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    # Apply the top-k adjustment for image tokens only when L == 1 (decoding).
    if L == 1 and image_regions is not None:
        # Gather the scores of all image positions at once. The rearrange
        # pattern requires a 5-D result of shape (b, h, 1, 1, n * l) and
        # splits the last axis into num_images regions of equal length.
        region_weights = attn_weight[..., image_regions]
        region_weights_reshaped = rearrange(region_weights, 'b h 1 1 (n l) -> b h 1 n l', n=num_images)

        region_length = region_weights_reshaped.size(-1)
        if region_length < image_top_k:
            raise ValueError(f"Region length {region_length} is less than top-k value {image_top_k}")

        # Top-k is taken independently inside every region (last axis).
        top_k_values, top_k_indices = torch.topk(region_weights_reshaped, image_top_k, dim=-1, largest=True)

        # Start from all -inf and write back only the top-k scores, so every
        # other image position gets zero weight after the softmax.
        mask = torch.full_like(region_weights_reshaped, float('-inf'))
        mask.scatter_(-1, top_k_indices, top_k_values)

        # Put the filtered scores back at the image positions.
        attn_weight[..., image_regions] = rearrange(mask, 'b h 1 n l -> b h 1 1 (n l)')

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

"""
original function
"""
# def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
#     # Efficient implementation equivalent to the following:
#     L, S = query.size(-2), key.size(-2)
#     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
#     attn_bias = torch.zeros(L, S, dtype=query.dtype)
#     if is_causal:
#         assert attn_mask is None
#         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
#         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
#         attn_bias.to(query.dtype)

#     if attn_mask is not None:
#         if attn_mask.dtype == torch.bool:
#             attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
#         else:
#             attn_bias += attn_mask
#     attn_weight = query @ key.transpose(-2, -1) * scale_factor
#     attn_weight += attn_bias
#     attn_weight = torch.softmax(attn_weight, dim=-1)
#     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
#     return attn_weight @ value