from typing import Optional, Tuple

import torch

from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
if is_flash_attn_2_available():
    from flash_attn import flash_attn_with_kvcache
    #flash_attn_with_kvcache = torch._dynamo.disable(flash_attn_with_kvcache)

# True when the installed flash-attn is older than the version checked by
# ``is_flash_attn_greater_or_equal_2_10``. ``flash_attention_forward`` passes it
# as ``use_top_left_mask``; presumably older flash-attn releases align the causal
# mask to the top-left corner instead of bottom-right (TODO confirm).
_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def _flash_attn_with_kvcache_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    cu_seq_lens_q: Optional[torch.LongTensor] = None,
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
    target_dtype: Optional[torch.dtype] = None,
    **kwargs,
):
    """Causal flash attention over a preallocated per-layer KV cache (inference path).

    The signature mirrors ``transformers``' ``_flash_attention_forward`` so that
    ``flash_attention_forward`` can call either implementation with the same
    arguments. Only these arguments are used:

    - ``query_states``, ``key_states``, ``value_states``
    - ``is_causal``
    - ``softmax_scale``, ``softcap``
    - ``target_dtype``

    These are accepted and ignored:

    - ``attention_mask``, ``query_length``, ``dropout``, ``position_ids``
    - ``use_top_left_mask``, ``deterministic``
    - the varlen ``cu_seq_lens_*`` / ``max_length_*`` arguments

    NOTE(review): ``sliding_window`` is also ignored. The kernel always runs
    with an unbounded window (``window_size=(-1, -1)``), so sliding-window
    models get full-context attention on this path.

    Args:
        query_states, key_states, value_states: projections for the new
            tokens, already transposed by the caller (presumably
            ``(batch, seq, heads, head_dim)`` -- TODO confirm).
        is_causal: must be true; only causal attention is supported.
        softmax_scale: scale applied to ``QK^T``. ``None`` lets the kernel
            use its default.
        softcap: logit soft-capping value, forwarded only when not ``None``.
        target_dtype: if set, q/k/v are cast to it before the kernel call.
            The caller sets it only for float32 inputs, matching the cast
            done on the non-cache path.
        **kwargs: must contain
            ``layer_idx`` (index into the caches),
            ``k_cache`` / ``v_cache`` (indexable by ``layer_idx``), and
            ``cache_seqlens`` (forwarded unchanged to the kernel).

    Returns:
        The attention output returned by ``flash_attn_with_kvcache``.

    Raises:
        AssertionError: if ``is_causal`` is false.
        RuntimeError: if the kernel fails. The kernel inputs are first saved
            in the current working directory as ``k_cache.pt``,
            ``v_cache.pt``, ``q.pt``, ``k.pt``, ``v.pt`` and
            ``cache_seqlens.pt``. The original error is chained as the cause.
    """
    assert is_causal

    layer_idx = kwargs["layer_idx"]
    # ``transpose`` returns a view (no copy, no kernel launch), so nothing
    # needs synchronising here. flash_attn_with_kvcache is documented to write
    # the new k/v into the caches in place; through the view, that update
    # lands in the caller's cache tensors. Assumes the caches are stored with
    # the sequence on dim 2, e.g. (batch, heads, seq, head_dim) -- TODO confirm.
    k_cache = kwargs["k_cache"][layer_idx].transpose(1, 2)
    v_cache = kwargs["v_cache"][layer_idx].transpose(1, 2)
    cache_seqlens = kwargs["cache_seqlens"]

    q, k, v = query_states, key_states, value_states
    if target_dtype is not None:
        # Mirror the fp32 -> half-precision cast of the non-cache path.
        # flash-attn kernels only accept fp16/bf16 inputs.
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)

    optional_kwargs = {}
    if softcap is not None:
        # Forward only when set: presumably older flash-attn releases (< 2.6)
        # do not accept a ``softcap`` argument.
        optional_kwargs["softcap"] = softcap

    try:
        attn_output = flash_attn_with_kvcache(
            q,
            k_cache,
            v_cache,
            k=k,
            v=v,
            rotary_cos=None,
            rotary_sin=None,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=None,
            block_table=None,
            softmax_scale=softmax_scale,
            causal=True,
            window_size=(-1, -1),  # -1 means infinite context window
            rotary_interleaved=True,
            alibi_slopes=None,
            **optional_kwargs,
        )
    except RuntimeError as err:
        # Save the exact kernel inputs so the failure can be reproduced offline.
        debug_inputs = {
            "k_cache": k_cache,
            "v_cache": v_cache,
            "q": q,
            "k": k,
            "v": v,
            "cache_seqlens": cache_seqlens,
        }
        for name, tensor in debug_inputs.items():
            torch.save(tensor, f"{name}.pt")
        raise RuntimeError(
            f"flash_attn_with_kvcache failed: {err}; "
            f"k_cache={tuple(k_cache.shape)}, v_cache={tuple(v_cache.shape)}, "
            f"q={tuple(q.shape)}, k={tuple(k.shape)}, v={tuple(v.shape)}, "
            f"cache_seqlens={cache_seqlens}; inputs saved as *.pt in the "
            "current working directory"
        ) from err
    return attn_output


def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """Attention-interface entry point that dispatches to flash-attention kernels.

    Steps:

    1. Transpose q/k/v from the caller's layout (sequence on dim 2) to the
       layout FA2 expects (sequence on dim 1).
    2. Work out which dtype float32 inputs should be cast back to.
    3. Call either ``_flash_attn_with_kvcache_forward`` (when
       ``kwargs["is_inference"]`` is truthy) or transformers'
       ``_flash_attention_forward``.

    Args:
        module: the attention layer; ``module.is_causal`` must be ``True``.
            When inputs are float32, its ``config`` and ``Linear``
            sub-modules are consulted to pick the target dtype.
        query, key, value: projections with the sequence length on dim 2
            (presumably ``(batch, heads, seq, head_dim)`` -- TODO confirm).
        attention_mask: forwarded unchanged to the selected implementation.
        dropout: forwarded as ``dropout``.
        scaling: forwarded as ``softmax_scale``.
        sliding_window: forwarded as ``sliding_window``.
        softcap: forwarded as ``softcap``.
        **kwargs: handled as follows:
            ``is_causal`` is dropped (the module's value is used instead);
            ``is_inference`` selects the KV-cache path;
            everything else is forwarded to the selected implementation.

    Returns:
        ``(attn_output, None)``. No attention weights are returned.

    Raises:
        AssertionError: if ``module.is_causal`` is not ``True``.
    """
    # Read the sequence length before the transpose, while it is still dim 2.
    seq_len = query.shape[2]

    # FA2 uses non-transposed inputs (sequence on dim 1).
    query, key, value = (t.transpose(1, 2) for t in (query, key, value))

    # In PEFT, layer norms are often cast to float32 for training stability,
    # which silently upcasts the hidden states too. Determine the dtype to
    # cast back to, since flash-attn kernels only run in half precision.
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            # ``get_autocast_gpu_dtype`` is deprecated in recent PyTorch;
            # prefer the device-generic API when it exists.
            if hasattr(torch, "get_autocast_dtype"):
                target_dtype = torch.get_autocast_dtype("cuda")
            else:
                target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(module.config, "_pre_quantization_dtype"):
            # Quantized model: use the dtype recorded before quantization.
            target_dtype = module.config._pre_quantization_dtype
        else:
            # Fall back to the first Linear layer's weight dtype. This raises
            # StopIteration if the module contains no Linear layer.
            target_dtype = next(
                layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)
            ).weight.dtype

    # The module's ``is_causal`` is authoritative; drop any caller-supplied
    # value so it is not passed twice.
    kwargs.pop("is_causal", None)

    is_inference = kwargs.pop("is_inference", False)
    assert module.is_causal is True

    attention_impl = _flash_attn_with_kvcache_forward if is_inference else _flash_attention_forward
    attn_output = attention_impl(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        **kwargs,
    )

    return attn_output, None
