from typing import Optional

import torch
from torch import Tensor
from torch import nn
from einops import rearrange

from megatron.core.transformer.enums import AttnMaskType
from megatron.core.packed_seq_params import PackedSeqParams

from gpatch.core.transformer.transformer_config import Gemma3TransformerConfig


class SPDAAttention(nn.Module):
    """Core attention implemented with ``torch.nn.functional.scaled_dot_product_attention``.

    Query/key/value arrive laid out as ``(seq, batch, heads, head_dim)``, are
    moved to the ``(batch, heads, seq, head_dim)`` layout SDPA expects, and the
    result is returned as ``(seq, batch, heads * head_dim)``.

    The softmax scale is always ``config.query_pre_attn_scalar ** -0.5``.

    NOTE(review): ``attn_mask_type``, ``attention_type``, ``k_channels``,
    ``v_channels`` and ``cp_comm_type`` are accepted but not used by this
    class, and ``softmax_scale`` is stored but never applied — presumably they
    exist for signature compatibility with other core-attention modules;
    confirm against the caller.
    """

    def __init__(
        self,
        config: Gemma3TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: Optional[float] = None,
        softmax_scale: Optional[float] = None,
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        cp_comm_type: str = "p2p",
    ):
        super().__init__()
        self.config = config
        self.layer_number = layer_number
        # Dropout probability for attention weights; None (or 0) disables dropout.
        self.attention_dropout = attention_dropout
        # Stored only; forward() scales by self.scaling, not by this value.
        self.softmax_scale = softmax_scale

        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.config.num_attention_heads // self.config.num_query_groups
        self.scaling = self.config.query_pre_attn_scalar**-0.5

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each key/value head ``n_rep`` times along the head dimension.

        This is the equivalent of ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.
        The hidden states go from (batch, num_query_groups, seqlen, head_dim) to
        (batch, num_query_groups * n_rep, seqlen, head_dim). When ``n_rep == 1``
        the input tensor is returned unchanged.
        """
        batch, num_query_groups, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        # expand() creates a broadcast view; the reshape below materialises it.
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_query_groups, n_rep, slen,
                                                               head_dim)
        return hidden_states.reshape(batch, num_query_groups * n_rep, slen, head_dim)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Optional[Tensor],
        attn_mask_type: AttnMaskType,
        attention_bias: Optional[Tensor] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
    ) -> Tensor:
        """Run scaled-dot-product attention.

        Args:
            query: Tensor laid out as ``(seq, batch, heads, head_dim)``.
            key: Tensor laid out as ``(seq, batch, kv_heads, head_dim)``.
            value: Tensor laid out as ``(seq, batch, kv_heads, head_dim)``.
            attention_mask: Optional mask forwarded to SDPA as ``attn_mask``
                after its last dimension is trimmed to the key sequence length.
                It is indexed with four indices, so it must be 4-D. When it is
                ``None`` and the query has more than one position, SDPA's
                built-in causal masking is used instead.
                NOTE(review): SDPA treats a boolean mask as "True = attend";
                confirm the caller supplies the mask in that convention.
            attn_mask_type: Accepted but not used here.
            attention_bias: Accepted but not used here.
            packed_seq_params: Accepted but not used here.

        Returns:
            Attention output of shape ``(seq, batch, heads * head_dim)``.
        """
        query = rearrange(query, "s b q d -> b q s d")
        key = rearrange(key, "s b h d -> b h s d")
        value = rearrange(value, "s b h d -> b h s d")
        if self.num_key_value_groups > 1:
            # Grouped-query attention: give every query head its own K/V copy.
            key = self.repeat_kv(key, self.num_key_value_groups)
            value = self.repeat_kv(value, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Trim the mask's last dimension to the number of key positions.
            causal_mask = causal_mask[:, :, :, :key.shape[-2]]

        # The rearranged tensors are strided views; hand SDPA contiguous inputs.
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # Use SDPA's built-in causal masking only when no explicit mask was
        # given and there is more than one query position.
        is_causal = query.shape[2] > 1 and causal_mask is None

        # While tracing, the shape comparison may produce a tensor; SDPA needs
        # a plain Python bool for `is_causal`.
        if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
            is_causal = is_causal.item()

        # F.scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # independent of the module's train/eval mode, so it must be switched
        # off explicitly outside of training.
        if self.training and self.attention_dropout:
            dropout_p = self.attention_dropout
        else:
            dropout_p = 0.0

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            scale=self.scaling,
            is_causal=is_causal,
        )
        # Back to sequence-first layout, then merge heads and head_dim.
        attn_output = rearrange(attn_output, "b h s d -> s b h d")
        attn_output = attn_output.reshape(*attn_output.shape[0:2], -1).contiguous()

        return attn_output
