import torch
from transformers import DynamicCache

from .....kernels import wait_for_ACT
from .....utils import is_flash_attention_available
from ....enums import PositionEmbeddingType
from ...position_embedding import apply_rotary_pos_emb
from .base import Attention


if is_flash_attention_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_func


class PaddingFreeAttention(Attention):
    """Attention over packed (padding-free) token sequences.

    All tensors handled here carry a flattened leading token dimension
    ``total_q`` instead of separate batch and sequence dimensions. The
    attention itself is computed by ``flash_attn_varlen_func``, which uses
    ``cu_seqlens`` / ``max_seqlen`` to describe the packed sequences.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        attention_mask: torch.Tensor | None = None,
        rope_cos_sin: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention on a packed batch and project the result.

        Args:
            hidden_states: input of shape (total_q, num_heads * head_dim).
            past_key_values: must be ``None``; a KV cache is rejected by an assertion.
            attention_mask: accepted for signature compatibility; never read here.
            rope_cos_sin: forwarded to ``apply_rotary_pos_emb`` when the position
                embedding type is rope; otherwise unused.
            cu_seqlens: passed to the kernel as both ``cu_seqlens_q`` and ``cu_seqlens_k``.
            max_seqlen: passed to the kernel as both ``max_seqlen_q`` and ``max_seqlen_k``.

        Returns:
            Tensor of shape (total_q, num_heads * head_dim) after the output
            projection and dropout.
        """
        assert past_key_values is None

        # hidden_states: (total_q, num_heads * head_dim)
        query, key, value = self._prepare_qkv_for_forward(hidden_states)
        # query: (total_q, num_heads, head_dim)
        # key / value: (total_q, num_key_value_heads, head_dim)

        # rotary embeddings only touch query and key; value is left untouched
        if self.position_embedding_type == PositionEmbeddingType.rope:
            query = apply_rotary_pos_emb(query, rope_cos_sin)
            key = apply_rotary_pos_emb(key, rope_cos_sin)

        # NOTE(review): wait_for_ACT is an opaque project helper; the flags below ask it to
        # wait during the forward pass only for the kernel inputs -- confirm exact semantics.
        query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False)
        key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False)
        value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False)

        # attention dropout is only active while training
        if self.training:
            attention_dropout = self.softmax_dropout_p
        else:
            attention_dropout = 0

        scale = self._get_softmax_scale()
        is_causal = self.causal

        # queries and keys come from the same packed sequences, so the same
        # cu_seqlens / max_seqlen describe both sides
        attention_output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            dropout_p=attention_dropout,
            softmax_scale=scale,
            causal=is_causal,
        )

        # drop the local references to the kernel inputs as soon as they are no longer needed
        del query, key, value

        # attention_output: (total_q, num_heads, head_dim)
        # for the kernel output the wait is requested for the backward pass only
        attention_output = wait_for_ACT(attention_output, wait_in_forward=False, wait_in_backward=True)

        # merge heads -> (total_q, num_heads * head_dim)
        attention_output = attention_output.view(-1, self.hidden_size)

        attention_output = self.c_proj(attention_output)
        return self.dropout(attention_output)

    def _prepare_qkv_for_forward_mha(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split a fused QKV projection into three equal per-head chunks.

        The input is viewed as (total_q, num_key_value_heads, -1) and the last
        dimension is chunked into query, key and value.
        """
        num_tokens = hidden_states.size(0)

        per_head = hidden_states.view(num_tokens, self.num_key_value_heads, -1)
        query, key, value = torch.chunk(per_head, 3, dim=-1)

        return query, key, value

    def _prepare_qkv_for_forward_gqa(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split a fused QKV projection laid out per key/value head group.

        After viewing the input as (total_q, num_key_value_heads, -1), each group
        holds ``num_heads // num_key_value_heads`` query heads followed by one key
        head and one value head. The query is flattened back to
        (total_q, -1, head_dim); key and value keep the grouped layout.
        """
        num_tokens = hidden_states.size(0)

        per_group = hidden_states.view(num_tokens, self.num_key_value_heads, -1)

        query_width = (self.num_heads // self.num_key_value_heads) * self.head_dim
        section_sizes = (query_width, self.head_dim, self.head_dim)
        query, key, value = torch.split(per_group, section_sizes, dim=-1)

        # reshape (not view) is required here -- presumably because the query slice is a
        # strided view of the fused tensor, so merging its dimensions may need a copy
        query = query.reshape(num_tokens, -1, self.head_dim)

        return query, key, value

    def _prepare_qkv_for_forward_mqa(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split a fused QKV projection with a single shared key/value head.

        The last dimension is cut into ``hidden_size`` query features followed by
        ``head_dim`` key and ``head_dim`` value features. The query is viewed as
        (total_q, num_heads, -1); key and value gain a singleton head dimension.
        """
        num_tokens = hidden_states.size(0)

        section_sizes = (self.hidden_size, self.head_dim, self.head_dim)
        flat_query, flat_key, flat_value = torch.split(hidden_states, section_sizes, dim=-1)

        query = flat_query.view(num_tokens, self.num_heads, -1)
        key = flat_key.unsqueeze(1)
        value = flat_value.unsqueeze(1)

        return query, key, value
