import torch
from torch import nn
from typing import Optional, Tuple, Callable, Dict, Any
from transformers.modeling_utils import AttentionInterface
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb

import logging

# Module-level logger named after this module; used to report attention-backend fallbacks.
logger = logging.getLogger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head dimension.

    Args:
        hidden_states: Tensor of shape (batch, num_key_value_heads, seq_len, head_dim).
        n_rep: Number of consecutive copies to make of each head.

    Returns:
        The input itself when ``n_rep == 1``; otherwise a tensor of shape
        (batch, num_key_value_heads * n_rep, seq_len, head_dim) in which the
        copies of a given head are adjacent.
    """
    # Unpack first so that a non-4D input is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a broadcast axis right after the head axis, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)


def convert_attention_mask(
    attention_mask: Optional[torch.Tensor],
    bs: int,
    num_heads: int,
    seq_len: int,
    kv_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Return a boolean mask that is True wherever a query may attend to a key.

    Args:
        attention_mask: ``None`` or a 4-D mask whose last dimension is cut to
            ``kv_len``. Boolean masks are used as-is; floating-point masks
            count entries ``>= 0`` as allowed; any other dtype counts
            non-zero entries as allowed.
        bs: Batch size of the returned mask.
        num_heads: Number of heads of the returned mask.
        seq_len: Number of query positions.
        kv_len: Number of key positions.
        device: Device for the mask built when ``attention_mask`` is ``None``.

    Returns:
        A boolean tensor expanded to (bs, num_heads, seq_len, kv_len).
    """
    out_shape = (bs, num_heads, seq_len, kv_len)

    if attention_mask is None:
        # Causal mask aligned to the bottom-right corner: the last query sees
        # every key, earlier queries see correspondingly fewer.
        causal = torch.ones(
            (seq_len, kv_len), dtype=torch.bool, device=device
        ).tril(diagonal=kv_len - seq_len)
        return causal.view(1, 1, seq_len, kv_len).expand(*out_shape)

    trimmed = attention_mask[:, :, :, :kv_len]
    if trimmed.dtype == torch.bool:
        allowed = trimmed
    elif trimmed.is_floating_point():
        # Additive-style mask: negative entries mean "blocked".
        allowed = trimmed >= 0
    else:
        # Integer-style mask: zero means "blocked".
        allowed = trimmed != 0
    return allowed.expand(*out_shape)


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    upcast_attn_matmul: bool = True,
    **kwargs,
):
    """Explicit (matmul + softmax) attention that also returns the weights.

    Args:
        module: Attention module; supplies ``num_key_value_groups`` and
            ``training``.
        query: Query tensor; dimensions 0, 1, 2 are read as batch, heads and
            query length.
        key: Key tensor; dimension 2 is the key length.
        value: Value tensor.
        attention_mask: Optional mask, converted to a boolean "allowed" mask by
            ``convert_attention_mask``.
        scaling: Factor applied to the raw query-key scores.
        dropout: Dropout probability applied to the attention weights.
        upcast_attn_matmul: Run the query-key matmul in float32 when True,
            otherwise in the query dtype.
        **kwargs: Accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``; the output has dimensions 1 and 2
        swapped, the weights are in the query dtype.
    """
    allowed = convert_attention_mask(
        attention_mask,
        query.shape[0],
        query.shape[1],
        query.shape[2],
        key.shape[2],
        query.device,
    )

    compute_dtype = torch.float32 if upcast_attn_matmul else query.dtype
    fill_value = torch.tensor(
        float("-inf"), device=query.device, dtype=compute_dtype
    )

    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    q_mat = query.to(compute_dtype).contiguous()
    k_mat_t = expanded_keys.to(compute_dtype).transpose(2, 3).contiguous()
    expanded_values = expanded_values.contiguous()

    scores = torch.matmul(q_mat, k_mat_t) * scaling
    # Disallowed positions get -inf so that softmax assigns them zero weight.
    scores = scores.to(torch.float32).masked_fill_(~allowed, fill_value)

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values)
    context = context.transpose(1, 2).contiguous()
    return context, probs


def sparsified_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    maximum_keys: int,
    recency_fraction: float = 0.5,
    dropout: float = 0.0,
    upcast_attn_matmul: bool = True,
    consistent_kv: bool = False,
    return_attn_mask: bool = False,
    **kwargs,
):
    """Eager attention restricted to a limited set of keys per query.

    With ``max_keys = min(maximum_keys, kv_len)``, the keys kept for a query
    are the ``int(max_keys * recency_fraction)`` most recent visible keys plus
    the ``max_keys - int(max_keys * recency_fraction)`` highest-scoring visible
    keys among the rest. All other keys are masked out before the softmax.

    Args:
        module: Attention module; supplies ``num_key_value_groups`` and
            ``training``.
        query: Query tensor of shape (bs, num_heads, seq_len, head_dim).
        key: Key tensor; dimension -2 is the key length.
        value: Value tensor.
        attention_mask: Optional mask, converted to a boolean "allowed" mask by
            ``convert_attention_mask``.
        scaling: Factor applied to the raw query-key scores.
        maximum_keys: Upper bound on the number of keys kept (must be >= 1).
        recency_fraction: Share of the key budget reserved for the most recent
            keys; must lie in [0.0, 1.0].
        dropout: Dropout probability applied to the attention weights.
        upcast_attn_matmul: Run the query-key matmul in float32 when True,
            otherwise in the query dtype.
        consistent_kv: When True, the selection made for the last query is
            shared by all queries, and each query additionally keeps the key
            at its own position.
        return_attn_mask: When True, also return the boolean keep mask.
        **kwargs: Accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``, or
        ``(attn_output, attn_weights, keep_mask)`` when ``return_attn_mask`` is
        True. ``attn_output`` has dimensions 1 and 2 swapped.

    Raises:
        ValueError: If ``maximum_keys`` < 1, if ``recency_fraction`` is outside
            [0.0, 1.0], or if ``consistent_kv`` is set and there are fewer
            keys than queries.
    """
    # Validate with exceptions rather than `assert`, which is removed under -O.
    if maximum_keys <= 0:
        raise ValueError("maximum_keys must be >= 1")
    if not 0.0 <= recency_fraction <= 1.0:
        raise ValueError("recency_fraction must be in [0.0, 1.0]")

    if upcast_attn_matmul:
        attn_matmul_dtype = torch.float32
    else:
        attn_matmul_dtype = query.dtype

    neg_inf = torch.tensor(float("-inf"), device=query.device, dtype=attn_matmul_dtype)
    bs, num_heads, seq_len, _ = query.shape
    kv_len = key.shape[-2]

    # Split the key budget between "most recent" and "top scoring" keys.
    max_keys = min(int(maximum_keys), kv_len)
    num_recent_keys = int(max_keys * recency_fraction)
    num_top_keys = max_keys - num_recent_keys

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    queries_for_matmul = query.to(attn_matmul_dtype).contiguous()
    keys_for_matmul = key_states.to(attn_matmul_dtype).transpose(2, 3).contiguous()
    value_states = value_states.contiguous()

    attn_weights = torch.matmul(queries_for_matmul, keys_for_matmul) * scaling

    causal_mask = convert_attention_mask(
        attention_mask, bs, num_heads, seq_len, kv_len, query.device
    )
    attn_weights = attn_weights.masked_fill_(~causal_mask, neg_inf)

    # For every key position: number of visible keys from that position to the
    # end of the row (reverse cumulative sum). Visible keys whose count is
    # <= num_recent_keys are the most recent ones. Invisible trailing positions
    # have count 0 and are removed again by the final AND with causal_mask.
    recent_counts = torch.cumsum(causal_mask.to(torch.int32).flip(-1), dim=-1).flip(-1)
    recent_mask = recent_counts <= num_recent_keys

    if consistent_kv:
        remaining_kvs = kv_len - seq_len
        if remaining_kvs < 0:
            raise ValueError("The number of KVs is less than the sequence length")
        # Each query always keeps its own position: an identity matrix aligned
        # to the right edge of the key axis.
        self_mask = torch.eye(seq_len, dtype=torch.bool, device=query.device)
        if remaining_kvs > 0:
            other_mask = torch.zeros(
                (seq_len, remaining_kvs), dtype=torch.bool, device=query.device
            )
            self_mask = torch.cat([other_mask, self_mask], dim=-1)
        # Share the last query's recent-key selection across all queries.
        keep_mask = torch.logical_or(
            self_mask[None, None, :, :], recent_mask[..., [-1], :]
        )
    else:
        # Clone because keep_mask is modified in place by scatter_ below.
        keep_mask = recent_mask.clone()

    if num_top_keys > 0:
        if consistent_kv:
            # Rank only the last query's scores; the result is shared below.
            non_recent_mask = torch.logical_and(
                causal_mask[..., [-1], :], (~recent_mask[..., [-1], :])
            )
            non_recent_scores = torch.where(
                non_recent_mask, attn_weights[..., [-1], :], neg_inf
            )
        else:
            non_recent_mask = torch.logical_and(causal_mask, (~recent_mask))
            non_recent_scores = torch.where(non_recent_mask, attn_weights, neg_inf)
        k = min(num_top_keys, kv_len)
        top_idx = torch.topk(non_recent_scores, k=k, dim=-1).indices
        if consistent_kv:
            top_idx = top_idx.expand(bs, num_heads, seq_len, k)
        keep_mask.scatter_(-1, top_idx, True)

    # top-k may pick -inf entries when a row has fewer than k candidates;
    # intersecting with the causal mask drops any invisible position again.
    keep_mask = torch.logical_and(keep_mask, causal_mask)
    attn_weights = attn_weights.masked_fill(~keep_mask, neg_inf)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    if return_attn_mask:
        return attn_output, attn_weights, keep_mask
    return attn_output, attn_weights


def sparsified_random_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:

    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None and causal_mask.ndim == 4:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    if is_causal is None:
        is_causal = query.shape[2] > 1 and causal_mask is None

    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None


class ConsistentSparsifiedAttentionWithMemory(nn.Module):
    """Sparsified attention whose key-selection mask is reused by later layers.

    Layers with ``layer_idx <= set_attention_mask_at_layer`` run
    ``sparsified_attention_forward`` and store the keep mask it returns.
    Layers above that index run ``sdpa_attention_forward`` with the stored
    mask in place of the last positional argument.

    When ``set_attention_mask_at_layer`` is ``None`` (the default) there is no
    cutoff layer: every layer runs the sparsified path and refreshes the
    stored mask.
    """

    def __init__(self, set_attention_mask_at_layer: Optional[int] = None):
        super().__init__()
        self.set_attention_mask_at_layer = set_attention_mask_at_layer
        # Keep mask produced by the most recent sparsified call, if any.
        self.cached_mask = None

    def forward(self, *args, layer_idx: int, **kwargs):
        """Dispatch to the sparsified or the mask-reusing SDPA path.

        Args:
            *args: Positional arguments for the attention function; the last
                one is replaced by the cached mask on the SDPA path.
            layer_idx: Index of the calling layer.
            **kwargs: Keyword arguments forwarded to the attention function.

        Returns:
            A 2-tuple ``(attn_output, attn_weights)``; the weights are ``None``
            on the SDPA path.

        Raises:
            ValueError: If the SDPA path is taken before any mask was cached.
        """
        cutoff = self.set_attention_mask_at_layer
        # `cutoff is None` must be checked first: comparing an int with None
        # raises TypeError.
        if cutoff is None or layer_idx <= cutoff:
            attn_output, attn_weights, keep_mask = sparsified_attention_forward(
                *args, **kwargs, return_attn_mask=True
            )
            self.cached_mask = keep_mask
            return attn_output, attn_weights

        if self.cached_mask is None:
            raise ValueError(
                "cached_mask is None, but layer_idx > "
                "set_attention_mask_at_layer"
            )
        # The last positional argument is replaced by the cached keep mask.
        # assumes callers pass attention_mask as the last positional — TODO confirm
        call_args = list(args)
        call_args[-1] = self.cached_mask
        return sdpa_attention_forward(*call_args, **kwargs)


# Registry mapping an implementation name to its attention function.
# `qwen2_custom_nope_attn` indexes it with `context_extension_attn_impl`.
# NOTE(review): annotated as `AttentionInterface`, but the value is a plain
# dict literal — confirm the intended type.
ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS: AttentionInterface = {
    "sparsified": sparsified_attention_forward,
    "eager": eager_attention_forward,
}


def apply_no_pos_emb(
    q,
    k,
    scaling,
    attention_scaling_coef=1.0,
    **kwargs,
):
    """Pass queries and keys through unchanged and rescale the attention scale.

    Args:
        q: Queries, returned as-is.
        k: Keys, returned as-is.
        scaling: Base attention scaling factor.
        attention_scaling_coef: Multiplier applied to ``scaling``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(q, k, scaling * attention_scaling_coef)``.
    """
    return q, k, scaling * attention_scaling_coef


def qwen2_custom_nope_attn(
    module,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    use_extended_attention_fn: bool,
    context_extension_attn_impl: Optional[str] = None,
    context_extension_attn_params: Dict[str, Any] = {},
    context_extension_nope_params: Dict[str, Any] = {},
    apply_rotary_position_embeddings: bool = False,
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Any,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Attention forward pass with optional rotary embeddings and pluggable kernels.

    Args:
        module: Attention module providing ``q_proj``, ``k_proj``, ``v_proj``,
            ``o_proj``, ``head_dim``, ``scaling``, ``layer_idx``, ``config``,
            ``training`` and ``attention_dropout``.
        hidden_states: Input tensor; all but its last dimension are kept as
            the output's leading shape.
        position_embeddings: ``(cos, sin)`` pair; applied to queries and keys
            only when ``apply_rotary_position_embeddings`` is True, and always
            forwarded to the cache update.
        attention_mask: Optional mask forwarded to the attention function.
        use_extended_attention_fn: When True the attention function is chosen
            from ``context_extension_attn_impl``; otherwise from
            ``module.config._attn_implementation``.
        context_extension_attn_impl: A registry name, a list of registry names
            indexed by ``module.layer_idx``, or a callable.
        context_extension_attn_params: Extra keyword arguments for the
            attention function. Only read, never mutated.
        context_extension_nope_params: Keyword arguments for
            ``apply_no_pos_emb``; must be empty when rotary embeddings are
            applied. Only read, never mutated.
        apply_rotary_position_embeddings: Apply rotary embeddings instead of
            the no-positional-embedding path.
        past_key_value: Optional cache, updated with the new keys and values.
        cache_position: Forwarded to the cache update.
        **kwargs: Forwarded to the attention function; ``layer_idx`` is added.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has been passed
        through ``module.o_proj``.

    Raises:
        ValueError: If ``use_extended_attention_fn`` is True and
            ``context_extension_attn_impl`` is not a string, list or callable.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, module.head_dim)

    query_states = module.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = module.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings

    # Start from the module's own scale so that `scaling` is defined on both
    # branches; only the no-positional-embedding branch rescales it.
    scaling = module.scaling

    if apply_rotary_position_embeddings:
        assert len(context_extension_nope_params) == 0, (
            "Rotary embeddings and NOPE are not meant to be applied " "simultaneously"
        )
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
    else:
        query_states, key_states, scaling = apply_no_pos_emb(
            query_states, key_states, module.scaling, **context_extension_nope_params
        )

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, module.layer_idx, cache_kwargs
        )

    sliding_window = None
    if (
        module.config.use_sliding_window
        and getattr(module.config, "sliding_window", None) is not None
        and module.layer_idx >= module.config.max_window_layers
    ):
        sliding_window = module.config.sliding_window

    attention_interface: Callable = eager_attention_forward

    if use_extended_attention_fn:
        if isinstance(context_extension_attn_impl, str):
            attention_interface = ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS[
                context_extension_attn_impl
            ]
        elif isinstance(context_extension_attn_impl, list):
            # One implementation name per layer.
            attention_interface = ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS[
                context_extension_attn_impl[module.layer_idx]
            ]
        elif callable(context_extension_attn_impl):
            attention_interface = context_extension_attn_impl
        else:
            raise ValueError(
                "context_extension_attn_impl must be a string, list of strings "
                "or a callable"
            )

    elif module.config._attn_implementation != "eager":
        if module.config._attn_implementation == "sdpa" and kwargs.get(
            "output_attentions", False
        ):
            # Keep the eager default: SDPA cannot return attention weights.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                module.config._attn_implementation
            ]

    kwargs["layer_idx"] = module.layer_idx

    attn_output, attn_weights = attention_interface(
        module,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not module.training else module.attention_dropout,
        scaling=scaling,
        sliding_window=sliding_window,
        **kwargs,
        **context_extension_attn_params,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = module.o_proj(attn_output)
    return attn_output, attn_weights
