from typing import Callable, Optional, Tuple, Union, Dict, Any
from functools import lru_cache, partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.attention.flex_attention import (
    create_block_mask,
    _mask_mod_signature,
    BlockMask,
    flex_attention,
    _score_mod_signature,
)


# flex_attention_compiled = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs", dynamic=False)
flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
# flex_attention_compiled = torch.compile(flex_attention)
# flex_attention_compiled = flex_attention

kernel_options = {}

# check CUDA type, if it is 'L40S', then add kernel options for L40S
if "L40S" in torch.cuda.get_device_properties(0).name:
    kernel_options = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_M1": 32,
        "BLOCK_N1": 64,
        "BLOCK_M2": 64,
        "BLOCK_N2": 32,
    }
elif "H100" in torch.cuda.get_device_properties(0).name or "H200" in torch.cuda.get_device_properties(0).name:
    kernel_options = {
        "BLOCK_M": 32,
        "BLOCK_N": 32,
        "num_stages": 2,
        "FORCE_USE_FLEX_ATTENTION": True,  # TODO inspect flex_decode
    }
elif "RTX 5000" in torch.cuda.get_device_properties(0).name:
    kernel_options = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_M1": 32,
        "BLOCK_N1": 64,
        "BLOCK_M2": 64,
        "BLOCK_N2": 32,
    }
elif "A40" in torch.cuda.get_device_properties(0).name:
    kernel_options = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_M1": 32,
        "BLOCK_N1": 64,
        "BLOCK_M2": 64,
        "BLOCK_N2": 32,
    }

def is_power_of_two(n: int) -> bool:
    """Return True when *n* is a positive power of two (1, 2, 4, ...).

    Zero is rejected explicitly; for any other integer, a power of two is
    the only value with a single set bit, so clearing the lowest set bit
    (``n & (n - 1)``) leaves nothing behind.
    """
    if n == 0:
        return False
    return not (n & (n - 1))

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The input must
    be 4-D. When ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it without copying,
    # then fold it into the head axis (the reshape materialises the copy).
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", **kwargs):
    block_mask = create_block_mask(score_mod, B=B, H=H, Q_LEN=M, KV_LEN=N, device=device, **kwargs)
    return block_mask


def find_multiple(n: int, k: int) -> int:
    """Round *n* up to the next multiple of *k* (unchanged if already a multiple)."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder


def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)

    return _mask_mod


def attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: BlockMask,
    scaling: float,
    **kwargs,
):
    """Run the compiled flex_attention kernel with a precomputed block mask.

    ``query``, ``key`` and ``value`` must be 4-D. Presumably they are
    (batch, heads, seq_len, head_dim), matching ``repeat_kv``'s convention.
    When query heads outnumber KV heads, flex_attention's ``enable_gqa``
    path is used if the group size is a power of two. Otherwise the K/V heads
    are repeated explicitly.

    ``module`` and ``**kwargs`` are accepted for interface compatibility and
    are not read here.

    Returns:
        ``(attn_output, None, None)``, where ``attn_output`` is the kernel output
        with dims 1 and 2 swapped and made contiguous.
    """
    _, num_q_heads, _, _ = query.shape
    _, num_kv_heads, _, _ = key.shape
    group_size = num_q_heads // num_kv_heads
    use_native_gqa = num_q_heads != num_kv_heads and is_power_of_two(group_size)

    if group_size != 1 and not use_native_gqa:
        # Native GQA not selected: materialise one K/V head per query head.
        key, value = repeat_kv(key, group_size), repeat_kv(value, group_size)

    output = flex_attention_compiled(
        query,
        key,
        value,
        scale=scaling,
        block_mask=attention_mask,
        enable_gqa=use_native_gqa,
        kernel_options=kernel_options,
    )
    return output.transpose(1, 2).contiguous(), None, None


def rnsa_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: BlockMask,
    forget_weights: torch.Tensor,
    kv_positions: Optional[torch.Tensor] = None,
    offset: int = 0,
    scaling: Optional[float] = None,
    fg_dropout: float = 0.0,
    **kwargs,
):
    """Flex attention whose logits are rescaled by a per-key decay factor.

    Every score is multiplied by ``exp(clamp(w * d, -70, 70))`` where
    ``w = forget_weights[b, kv_head, kv_idx]`` and
    ``d = q_idx + offset - pos``. Here ``pos`` is
    ``kv_positions[b, kv_head, kv_idx]`` when ``kv_positions`` is given,
    otherwise ``kv_idx`` itself. ``kv_head`` is the query head index
    divided by the query-heads-per-KV-head group size.

    Args:
        module: Only ``module.training`` is read (gates forget-weight dropout).
        query: 4-D tensor, presumably (batch, num_q_heads, q_len, head_dim).
        key, value: 4-D tensors, presumably (batch, num_kv_heads, kv_len, head_dim).
        attention_mask: Block mask passed to flex_attention as ``block_mask``.
        forget_weights: Indexed as ``[b, kv_head, kv_idx]``; cast to float32.
        kv_positions: Optional, indexed as ``[b, kv_head, kv_idx]``.
        offset: Added to every query index before computing the distance.
        scaling: Forwarded as ``scale``; ``None`` defers to flex_attention's default.
        fg_dropout: Dropout probability applied to ``forget_weights`` while training.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``(attn_output, None, None)``, where ``attn_output`` is the kernel output
        with dims 1 and 2 swapped and made contiguous.
    """
    _, n_q_heads, _, _ = query.shape
    _, n_kv_heads, _, _ = key.shape
    group_size = n_q_heads // n_kv_heads
    # flex_attention's native GQA is only selected for power-of-two group
    # sizes; any other grouping materialises repeated K/V heads instead.
    enable_gqa = (n_q_heads != n_kv_heads) and is_power_of_two(group_size)
    if not enable_gqa and group_size != 1:
        key = repeat_kv(key, group_size)
        value = repeat_kv(value, group_size)

    forget_weights = forget_weights.to(torch.float32)
    if fg_dropout > 0.0 and module.training:
        # F.dropout zeroes entries (exp(0) == 1, i.e. no decay for that key)
        # and rescales the kept ones by 1 / (1 - p).
        forget_weights = F.dropout(forget_weights, p=fg_dropout, training=module.training)

    # forget_weights / kv_positions keep one entry per KV head, so query head h
    # maps to KV head h // group_size (even when K/V were repeated above).
    # The exponent is clamped to [-70, 70] so exp() stays finite in float32.
    def score_mod_w_kv_pos(score, b, h, q_idx, kv_idx):
        kv_head = h // group_size
        distance = q_idx + offset - kv_positions[b, kv_head, kv_idx]
        return score * torch.exp((forget_weights[b, kv_head, kv_idx] * distance).clamp(min=-70, max=70))

    def score_mod_wo_kv_pos(score, b, h, q_idx, kv_idx):
        kv_head = h // group_size
        distance = q_idx + offset - kv_idx
        return score * torch.exp((forget_weights[b, kv_head, kv_idx] * distance).clamp(min=-70, max=70))

    score_mod = score_mod_w_kv_pos if kv_positions is not None else score_mod_wo_kv_pos

    attn_output = flex_attention_compiled(
        query,
        key,
        value,
        scale=scaling,
        block_mask=attention_mask,
        score_mod=score_mod,
        enable_gqa=enable_gqa,
        kernel_options=kernel_options,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None, None


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    is_causal: bool = False,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference scaled-dot-product attention computed with plain matmuls.

    Args:
        module: Only ``module.training`` is read (gates attention dropout).
        query: 4-D tensor, presumably (batch, num_q_heads, q_len, head_dim).
        key, value: 4-D tensors, presumably (batch, num_kv_heads, kv_len, head_dim).
            If ``num_q_heads`` is a multiple of ``num_kv_heads``, the K/V heads
            are repeated to match (grouped-query attention).
        attention_mask: Optional additive mask, sliced to ``[:, :, :, :kv_len]``
            and added to the logits. It takes precedence over ``is_causal``.
        scaling: Multiplier applied to ``query @ key^T``.
        is_causal: When no mask is given, build a bottom-right-aligned causal mask.
        dropout: Dropout probability on the attention weights while training.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The attention output with dims 1 and 2 swapped, made contiguous.
        Only the tensor is returned, unlike the flex variants, which return a
        3-tuple.
    """
    n_q_heads = query.shape[1]
    n_kv_heads = key.shape[1]
    # Grouped-query attention: expand K/V heads to one per query head, as the
    # flex-attention paths in this module do. Equal head counts are left as is.
    if n_q_heads != n_kv_heads and n_kv_heads > 0 and n_q_heads % n_kv_heads == 0:
        n_rep = n_q_heads // n_kv_heads
        key_states = torch.repeat_interleave(key, n_rep, dim=1)
        value_states = torch.repeat_interleave(value, n_rep, dim=1)
    else:
        key_states = key
        value_states = value

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # Trim the mask to the current key length (it may presumably be longer).
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    elif is_causal:
        min_dtype = torch.finfo(query.dtype).min
        batch_size, num_heads, seq_len, _ = query.shape
        target_len = key_states.shape[-2]
        # No dtype= is given, so the mask uses torch's default dtype.
        causal_mask = torch.full(
            (seq_len, target_len), device=query.device, fill_value=min_dtype
        )
        # Query row i sits at absolute position target_len - seq_len + i and
        # must not see keys after it (bottom-right aligned causal mask).
        diagonal_attend_mask = torch.arange(target_len, device=query.device) > torch.arange(target_len - seq_len, target_len, device=query.device).view(-1, 1)
        # Allowed positions become 0; blocked ones keep min_dtype.
        causal_mask *= diagonal_attend_mask
        causal_mask = causal_mask.view(1, 1, seq_len, target_len).expand(batch_size, num_heads, seq_len, target_len)
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output


# Public attention entry points re-exported by `from ... import *`; the
# helpers above remain importable by name but are not listed here.
__all__ = [
    "attention_forward",
    "eager_attention_forward",
    "rnsa_attention_forward",
]
