import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    eager_attention_forward as qwen2_eager_attention_forward,
)
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    eager_attention_forward as llama_eager_attention_forward,
    LlamaRMSNorm,
)

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, sdpa_attention_forward


from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict, Any, Callable
from .extending_the_nope import (
    ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS,
    qwen2_custom_nope_attn,
)

import logging


logger = logging.getLogger(__name__)


class SequenceMixingLayer(nn.Module, ABC):
    """Abstract base class for the sequence-mixing (attention) wrappers in this file.

    The base class stores the model config and the layer index, records the
    optional context-extension settings, and provides a helper that copies
    parameters from another module into the wrapped one.  Subclasses must
    implement ``forward`` and ``_get_module``.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
        **deprecated_kwargs,
    ):
        """Store the config/layer index and apply the context-extension settings.

        Args:
            config: Model configuration; only ``config.hidden_size`` is read here.
            layer_idx: Index of this layer; stored and used in log messages.
            context_extension_attn_impl: Optional implementation name (or
                callable) forwarded to ``set_context_extension_mode``.
            context_extension_attn_params: Optional parameters for that
                implementation.
            context_extension_nope_params: Optional NoPE-specific parameters.
            **deprecated_kwargs: Accepted and ignored, so that subclasses may
                keep forwarding legacy arguments such as ``add_features_dim``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.layer_idx = layer_idx
        self.set_context_extension_mode(
            context_extension_attn_impl=context_extension_attn_impl,
            context_extension_attn_params=context_extension_attn_params,
            context_extension_nope_params=context_extension_nope_params,
        )

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the sequence-mixing operation; implemented by subclasses."""

        pass

    def copy_attention_weights(self, source_module: nn.Module):
        """Copy same-named, same-shaped parameters from ``source_module``.

        The target is the module returned by ``_get_module()``.  Parameters
        that are missing in the target, or whose shapes differ, are left
        untouched and reported with a warning.

        Args:
            source_module: Module whose ``named_parameters()`` are copied from.
        """
        target_params = dict(self._get_module().named_parameters())

        for param_name, source_param in source_module.named_parameters():
            target_param = target_params.get(param_name)
            if target_param is None:
                logger.warning(
                    f"Parameter {param_name} not found in target layer {self.layer_idx}"
                )
                continue

            if source_param.shape != target_param.shape:
                logger.warning(
                    f"Shape mismatch for {param_name} in layer {self.layer_idx}: source {source_param.shape} vs target {target_param.shape}"
                )
                continue

            # In-place copy without recording the operation in autograd.
            with torch.no_grad():
                target_param.copy_(source_param)
            logger.debug(
                f"Copied {param_name} in layer {self.layer_idx}: {source_param.shape}"
            )

    def set_context_extension_mode(
        self,
        context_extension_attn_impl: Optional[str],
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Record which (if any) context-extension attention should be used.

        Sets ``context_extension_attn_impl``, ``context_extension_attn_params``,
        ``context_extension_nope_params``, ``use_extended_attention_fn`` and
        ``use_custom_attention`` on the instance.

        Args:
            context_extension_attn_impl: ``None`` to disable the extended
                attention function, a callable, or the name of an entry in
                ``ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS``.  Names are stored
                lower-cased.
            context_extension_attn_params: Parameters for the implementation;
                ``None`` is stored as an empty dict.
            context_extension_nope_params: NoPE parameters; ``None`` is stored
                as an empty dict.  A non-empty dict alone is enough to set
                ``use_custom_attention``.

        Raises:
            AssertionError: If the implementation is neither callable nor a
                registered name.
        """
        self.context_extension_attn_impl = context_extension_attn_impl
        self.use_extended_attention_fn = context_extension_attn_impl is not None
        self.context_extension_attn_params = context_extension_attn_params or {}
        self.context_extension_nope_params = context_extension_nope_params or {}
        self.use_custom_attention = self.use_extended_attention_fn or bool(
            self.context_extension_nope_params
        )

        # Nothing to validate when disabled; callables are stored as given.
        if not self.use_extended_attention_fn or callable(
            context_extension_attn_impl
        ):
            return

        # The name is stored lower-cased, so the lower-cased spelling must be
        # accepted by the validation as well (previously only the raw spelling
        # was checked, making e.g. an upper-cased registered name fail).
        is_registered = isinstance(context_extension_attn_impl, str) and (
            context_extension_attn_impl in ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS
            or context_extension_attn_impl.lower()
            in ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS
        )
        assert is_registered, (
            "Invalid context extension implementation: "
            f"{self.context_extension_attn_impl}. Available implementations:"
            f"{ALL_NOPE_EXTENDED_ATTENTION_FUNCTIONS}"
        )
        self.context_extension_attn_impl = context_extension_attn_impl.lower()

    @abstractmethod
    def _get_module(self) -> nn.Module:
        """Return the wrapped module whose parameters ``copy_attention_weights`` fills."""

        pass


class Qwen3AttentionSequenceMixing(SequenceMixingLayer):
    """Sequence-mixing layer that delegates directly to a ``Qwen3Attention`` module."""

    def __init__(self, config, layer_idx: int):
        """Build the wrapped ``Qwen3Attention`` for ``layer_idx``."""
        super().__init__(config, layer_idx)
        self.attention = Qwen3Attention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Pass every argument through to the wrapped attention and return its result."""
        wrapped = self.attention
        return wrapped(*args, **kwargs)


class Qwen2NoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``Qwen2Attention`` driven with constant position embeddings (cos=1, sin=0).

    When an extended attention implementation is configured, the call is routed
    through ``qwen2_custom_nope_attn`` instead of the wrapped module's own
    forward.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Forward all settings to the base class and build the wrapped attention.

        ``add_features_dim`` is handed to the base class as a keyword it
        accepts but does not use.
        """
        super().__init__(
            config,
            layer_idx,
            add_features_dim=add_features_dim,
            context_extension_attn_impl=context_extension_attn_impl,
            context_extension_attn_params=context_extension_attn_params,
            context_extension_nope_params=context_extension_nope_params,
        )
        self.attention = Qwen2Attention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        add_features: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention with the incoming (cos, sin) replaced by (ones, zeros).

        Args:
            hidden_states: Input passed unchanged to the attention call.
            position_embeddings: ``(cos, sin)`` pair; only its shape, dtype
                and device are reused.
            *args: Extra positional arguments forwarded as-is.
            add_features: Accepted but not used by this method.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the selected attention call returns.
        """
        rope_cos, rope_sin = position_embeddings
        # Presumably this turns the wrapped module's rotary step into a no-op
        # -- TODO confirm against Qwen2Attention.
        constant_embeddings = (
            torch.ones_like(rope_cos),
            torch.zeros_like(rope_sin),
        )

        if not self.use_extended_attention_fn:
            return self.attention(
                hidden_states, constant_embeddings, *args, **kwargs
            )

        return qwen2_custom_nope_attn(
            self.attention,
            hidden_states,
            constant_embeddings,
            *args,
            use_extended_attention_fn=self.use_extended_attention_fn,
            context_extension_attn_impl=self.context_extension_attn_impl,
            context_extension_attn_params=self.context_extension_attn_params,
            context_extension_nope_params=self.context_extension_nope_params,
            **kwargs,
        )


class Qwen2HighRoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``Qwen2Attention`` that keeps the rotary values of only the first channels.

    The first ``hf_channels`` entries of the last dimension of cos/sin are
    kept; all remaining entries are replaced by a constant cos and a zero sin.
    """

    def __init__(self, config, layer_idx: int, hf_channels: int = 6, **kwargs):
        """Build the wrapped attention and remember how many channels to keep.

        Extra ``**kwargs`` are accepted and ignored.
        """
        super().__init__(config, layer_idx)
        self.attention = Qwen2Attention(config, layer_idx)
        self.hf_channels = hf_channels

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Call the wrapped attention with partially neutralised (cos, sin).

        Args:
            hidden_states: Input passed unchanged to the wrapped attention.
            position_embeddings: ``(cos, sin)`` pair; indexed with three
                leading indices, so at least rank 3 is required.
            *args: Extra positional arguments forwarded as-is.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the wrapped attention returns.
        """
        rope_cos, rope_sin = position_embeddings
        kept = self.hf_channels

        # Single element of cos used as the constant fill value; presumably
        # the value at position 0, i.e. the embedding's scale factor --
        # TODO confirm.
        fill_value = rope_cos[0, 0, 0]

        cos = torch.ones_like(rope_cos) * fill_value
        sin = torch.zeros_like(rope_sin)
        # NOTE(review): only the leading `kept` entries of the last dimension
        # are restored; confirm this matches the channel pairing used by the
        # wrapped attention's rotary embedding.
        cos[..., :kept] = rope_cos[..., :kept]
        sin[..., :kept] = rope_sin[..., :kept]

        return self.attention(hidden_states, (cos, sin), *args, **kwargs)


class Qwen2ConvNoPEAttentionSequenceMixing(SequenceMixingLayer):
    """Left-padded 1-D convolution followed by ``Qwen2Attention`` with constant (cos, sin).

    The convolution is initialised so that it initially returns its input
    unchanged: all weights and biases are zero except an identity matrix on
    the last kernel tap.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        conv_kernel_size: int = 17,
        **kwargs,
    ):
        """Build the wrapped attention and the identity-initialised convolution.

        ``add_features_dim`` and ``**kwargs`` are accepted but not used.
        """
        super().__init__(config, layer_idx)
        self.attention = Qwen2Attention(config, layer_idx)

        width = config.hidden_size
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=conv_kernel_size,
            padding=0,
            bias=True,
        )

        with torch.no_grad():
            for tensor in (self.conv.weight, self.conv.bias):
                tensor.zero_()
            identity = torch.eye(
                width,
                device=self.conv.weight.device,
                dtype=self.conv.weight.dtype,
            )
            # Only the last tap is non-zero; with the left padding applied in
            # forward, that tap lines up with the current position.
            self.conv.weight[..., conv_kernel_size - 1].copy_(identity)

    def _get_module(self):
        """Return the wrapped attention module (the convolution is not included)."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        add_features: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convolve along dimension 1 of ``hidden_states``, then run attention.

        Args:
            hidden_states: Input; dimensions 1 and 2 are swapped before and
                after the convolution.
            position_embeddings: ``(cos, sin)`` pair, replaced by
                ``(ones, zeros)`` of the same shape.
            *args: Extra positional arguments forwarded as-is.
            add_features: Accepted but not used by this method.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the wrapped attention returns.
        """
        (kernel,) = self.conv.kernel_size
        (dilation,) = self.conv.dilation
        left_pad = dilation * (kernel - 1)

        # Pad only on the left so the convolution output keeps the input
        # length and never reads positions to the right of the current one.
        channels_first = F.pad(hidden_states.transpose(1, 2), (left_pad, 0))
        mixed = self.conv(channels_first).transpose(1, 2)

        rope_cos, rope_sin = position_embeddings
        constant_embeddings = (
            torch.ones_like(rope_cos),
            torch.zeros_like(rope_sin),
        )
        return self.attention(mixed, constant_embeddings, *args, **kwargs)


class Qwen2SWANSequenceMixing(SequenceMixingLayer):
    """``Qwen2Attention`` whose behaviour depends on ``layer_idx % 4``.

    Layers whose index is a multiple of four replace (cos, sin) with
    (ones, zeros) and honour ``set_softmax_scale``.  All other layers keep the
    incoming position embeddings and write sliding-window settings onto the
    config before the attention module is built.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Forward settings to the base class, adjust the config, build the attention."""
        super().__init__(
            config,
            layer_idx,
            add_features_dim=add_features_dim,
            context_extension_attn_impl=context_extension_attn_impl,
            context_extension_attn_params=context_extension_attn_params,
            context_extension_nope_params=context_extension_nope_params,
        )
        if layer_idx % 4:
            # NOTE(review): `config` is modified in place, so these values
            # stay on the object for anything else that holds it -- confirm
            # that is intended.
            # NOTE(review): the attribute written is `sliding_window_size`,
            # while other code in this file reads `config.sliding_window` --
            # confirm which name Qwen2Attention consumes.
            window_settings = (
                ("use_sliding_window", True),
                ("sliding_window_size", 100),
                ("max_window_layers", 28),
            )
            for attr_name, attr_value in window_settings:
                setattr(config, attr_name, attr_value)
        self.attention = Qwen2Attention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Rescale attention on layers with ``layer_idx % 4 == 0``; no-op otherwise."""
        if self.layer_idx % 4:
            return
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        add_features: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped attention, neutralising (cos, sin) on every fourth layer.

        Args:
            hidden_states: Input passed unchanged to the wrapped attention.
            position_embeddings: ``(cos, sin)`` pair; replaced by
                ``(ones, zeros)`` only when ``layer_idx % 4 == 0``.
            *args: Extra positional arguments forwarded as-is.
            add_features: Accepted but not used by this method.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the wrapped attention returns.
        """
        embeddings = position_embeddings
        if self.layer_idx % 4 == 0:
            rope_cos, rope_sin = position_embeddings
            embeddings = (
                torch.ones_like(rope_cos),
                torch.zeros_like(rope_sin),
            )
        return self.attention(hidden_states, embeddings, *args, **kwargs)


class Qwen2AliBiAttentionSequenceMixing(SequenceMixingLayer):
    """``Qwen2Attention`` projections combined with an ALiBi-style additive bias.

    The q/k/v/o projections of the wrapped module are reused, but attention is
    computed here through ``sdpa_attention_forward`` with a mask that contains
    ``slope * (key_position - query_position)`` per head plus causal masking.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention and register the per-head slopes buffer.

        The ``add_features_dim`` and ``context_extension_*`` arguments are
        accepted but not used by this class.
        """
        super().__init__(config, layer_idx)
        self.attention = Qwen2Attention(config, layer_idx)

        n_heads = config.num_attention_heads
        slopes = self._alibi_slopes(n_heads)
        self.register_buffer("slopes", torch.tensor(slopes).view(1, n_heads, 1, 1))

    @staticmethod
    def _alibi_slopes(n_heads: int) -> list[float]:
        """Return one slope per head as a geometric sequence.

        For a power-of-two head count the sequence is used directly.
        Otherwise the slopes for the largest power of two below ``n_heads``
        are extended with every second slope of the next power of two.
        """

        def _slopes_pow2(m: int) -> list[float]:
            start = 2 ** (-(2 ** -(math.log2(m) - 3)))
            ratio = start
            return [start * (ratio**i) for i in range(m)]

        if float(math.log2(n_heads)).is_integer():
            # Fix: the head count must be passed; the argument was missing,
            # which raised TypeError for every power-of-two head count.
            return _slopes_pow2(n_heads)

        cp2 = 2 ** math.floor(math.log2(n_heads))
        slopes = _slopes_pow2(cp2)
        slopes += _slopes_pow2(2 * cp2)[0::2][: n_heads - cp2]
        return slopes

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention with an ALiBi bias instead of rotary embeddings.

        Args:
            hidden_states: Input whose last dimension is projected to q/k/v.
            position_embeddings: ``(cos, sin)`` pair; only used to build
                ones/zeros tensors handed to the cache update.
            attention_mask: Optional mask.  2-D and 3-D masks are treated as
                1 = keep / 0 = drop; 4-D boolean masks as True = keep; other
                4-D masks are used as additive values.
            past_key_value: Optional cache object exposing ``update``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to ``sdpa_attention_forward``.

        Returns:
            ``(attn_output, None)`` where ``attn_output`` is the result of the
            output projection.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention.head_dim)

        query_states = (
            self.attention.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        key_states = (
            self.attention.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        value_states = (
            self.attention.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )

        original_cos, original_sin = position_embeddings
        cos = torch.ones_like(original_cos)
        sin = torch.zeros_like(original_sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.attention.layer_idx,
                cache_kwargs,
            )

        bsz, n_heads, q_len, _ = query_states.shape
        k_len = key_states.shape[-2]
        device = query_states.device
        dtype = query_states.dtype
        neg_inf = torch.finfo(dtype).min

        # Queries are the last q_len positions of the k_len keys.
        k_idx = torch.arange(k_len, device=device)
        q_idx = torch.arange(q_len, device=device) + (k_len - q_len)
        rel = k_idx.view(1, 1, 1, k_len) - q_idx.view(1, 1, q_len, 1)

        # Fix: a key is in the future iff its position exceeds the query's
        # position.  The previous triu(diagonal=1) ignored the k_len - q_len
        # offset and masked almost every cached key during decoding.
        future_keys = rel > 0

        slopes = self.slopes.to(device=device, dtype=dtype)
        # Fix: mask before broadcasting over the batch, out of place, instead
        # of calling masked_fill_ on an expanded (stride-0) tensor.
        alibi_biased_mask = (slopes * rel.to(dtype)).masked_fill(future_keys, neg_inf)
        alibi_biased_mask = alibi_biased_mask.expand(bsz, n_heads, q_len, k_len)

        base_mask = None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                m = attention_mask.to(dtype)
                base_mask = (1.0 - m)[:, None, None, :] * neg_inf
            elif attention_mask.dim() == 3:
                m = attention_mask.to(dtype)
                base_mask = (1.0 - m)[:, None, :, :] * neg_inf
            elif attention_mask.dim() == 4:
                if attention_mask.dtype == torch.bool:
                    base_mask = torch.zeros_like(attention_mask, dtype=dtype)
                    base_mask = base_mask.masked_fill(~attention_mask, neg_inf)
                else:
                    base_mask = attention_mask.to(dtype)

            if base_mask is not None:
                # Crop masks that are larger than the current q/k extent.
                if base_mask.shape[-2] != q_len or base_mask.shape[-1] != k_len:
                    base_mask = base_mask[..., :q_len, :k_len]
                if base_mask.shape[1] == 1:
                    base_mask = base_mask.expand(bsz, n_heads, q_len, k_len)
                alibi_biased_mask = alibi_biased_mask + base_mask

        attn_output, _ = sdpa_attention_forward(
            self.attention,
            query_states,
            key_states,
            value_states,
            attention_mask=alibi_biased_mask,
            dropout=(
                0.0 if not self.attention.training else self.attention.attention_dropout
            ),
            scaling=self.attention.scaling,
            # Causality is already encoded in the additive mask.
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.attention.o_proj(attn_output)

        return attn_output, None

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        self.attention.scaling = self.attention.head_dim**-0.5 * softmax_scale


class Qwen3NoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``Qwen3Attention`` driven with constant position embeddings (cos=1, sin=0)."""

    def __init__(self, config, layer_idx: int, add_features_dim: Optional[int] = None):
        """Build the wrapped attention; ``add_features_dim`` is accepted but unused."""
        super().__init__(config, layer_idx)
        self.attention = Qwen3Attention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        add_features: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped attention with (cos, sin) replaced by (ones, zeros).

        Args:
            hidden_states: Input passed unchanged to the wrapped attention.
            position_embeddings: ``(cos, sin)`` pair; only its shape, dtype
                and device are reused.
            *args: Extra positional arguments forwarded as-is.
            add_features: Accepted but not used by this method.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the wrapped attention returns.
        """
        rope_cos, rope_sin = position_embeddings
        unit_cos = torch.ones_like(rope_cos)
        zero_sin = torch.zeros_like(rope_sin)
        return self.attention(hidden_states, (unit_cos, zero_sin), *args, **kwargs)


class LlamaNoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``LlamaAttention`` driven with constant position embeddings (cos=1, sin=0).

    The context-extension settings are stored by the base class; ``forward``
    in this class does not consult them.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Forward all settings to the base class and build the wrapped attention."""
        super().__init__(
            config,
            layer_idx,
            add_features_dim=add_features_dim,
            context_extension_attn_impl=context_extension_attn_impl,
            context_extension_attn_params=context_extension_attn_params,
            context_extension_nope_params=context_extension_nope_params,
        )
        self.attention = LlamaAttention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        *args,
        add_features: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped attention with (cos, sin) replaced by (ones, zeros).

        Args:
            hidden_states: Input passed unchanged to the wrapped attention.
            position_embeddings: ``(cos, sin)`` pair; only its shape, dtype
                and device are reused.
            *args: Extra positional arguments forwarded as-is.
            add_features: Accepted but not used by this method.
            **kwargs: Extra keyword arguments forwarded as-is.

        Returns:
            Whatever the wrapped attention returns.
        """
        rope_cos, rope_sin = position_embeddings
        unit_cos = torch.ones_like(rope_cos)
        zero_sin = torch.zeros_like(rope_sin)
        return self.attention(hidden_states, (unit_cos, zero_sin), *args, **kwargs)


class Qwen2CosineNoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``Qwen2Attention`` projections with RMS-normalised queries and keys.

    Queries and keys are divided by their root-mean-square over the last
    dimension before attention.  The incoming position embeddings are only
    used to build ones/zeros tensors that are handed to the cache update.
    """

    def __init__(self, config, layer_idx: int, add_features_dim: Optional[int] = None):
        """Build the wrapped attention; ``add_features_dim`` is accepted but unused."""
        super().__init__(config, layer_idx)
        self.attention = Qwen2Attention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over RMS-normalised queries and keys.

        Args:
            hidden_states: Input whose last dimension is projected to q/k/v.
            position_embeddings: ``(cos, sin)`` pair; only its shape, dtype
                and device are reused.
            attention_mask: Passed unchanged to the attention function.
            past_key_value: Optional cache object exposing ``update``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention function;
                ``output_attentions`` is also inspected here.

        Returns:
            The two-element tuple ``(attn_output, attn_weights)`` -- note the
            annotation lists three elements.
        """
        attn = self.attention
        cfg = attn.config
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn.head_dim)

        def _split_heads(projection):
            # Project, split the last dimension into heads, swap dims 1 and 2.
            return projection(hidden_states).view(hidden_shape).transpose(1, 2)

        def _rms_normalize(states):
            rms = states.pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(1e-6)
            return states / rms

        query_states = _rms_normalize(_split_heads(attn.q_proj))
        key_states = _rms_normalize(_split_heads(attn.k_proj))
        value_states = _split_heads(attn.v_proj)

        rope_cos, rope_sin = position_embeddings
        cos = torch.ones_like(rope_cos)
        sin = torch.zeros_like(rope_sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        window_enabled = (
            cfg.use_sliding_window
            and getattr(cfg, "sliding_window", None) is not None
            and attn.layer_idx >= cfg.max_window_layers
        )
        sliding_window = cfg.sliding_window if window_enabled else None

        implementation = cfg._attn_implementation
        if implementation == "eager":
            attention_interface = qwen2_eager_attention_forward
        elif implementation == "sdpa" and kwargs.get("output_attentions", False):
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            attention_interface = qwen2_eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            sliding_window=sliding_window,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return attn.o_proj(merged), attn_weights


class LlamaQKNormNoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``LlamaAttention`` projections with weight-free RMS norm on queries and keys.

    ``forward`` re-implements the attention call using the wrapped module's
    projections; the position embeddings are only handed to the cache update.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention.

        ``add_features_dim`` and the ``context_extension_*`` arguments are
        accepted but not used by this class.
        """
        super().__init__(config, layer_idx)
        self.attention = LlamaAttention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute attention with queries and keys RMS-normalised over the last dimension.

        Args:
            hidden_states: Input whose last dimension is projected to q/k/v.
            position_embeddings: ``(cos, sin)`` pair, forwarded to the cache
                update only.
            attention_mask: Passed unchanged to the attention function.
            past_key_value: Optional cache object exposing ``update``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)`` after the output projection.
        """
        attn = self.attention
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn.head_dim)

        def _split_heads(projection):
            # Project, split the last dimension into heads, swap dims 1 and 2.
            return projection(hidden_states).view(hidden_shape).transpose(1, 2)

        def _rms(states):
            return F.rms_norm(
                states, (states.shape[-1],), eps=self.config.rms_norm_eps
            )

        query_states = _split_heads(attn.q_proj)
        key_states = _rms(_split_heads(attn.k_proj))
        value_states = _split_heads(attn.v_proj)
        query_states = _rms(query_states)

        cos, sin = position_embeddings
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = attn.config._attn_implementation
        attention_interface: Callable = (
            llama_eager_attention_forward
            if implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return attn.o_proj(merged), attn_weights


class LlamaQKLearnableRMSNormAttentionSequenceMixing(SequenceMixingLayer):
    """``LlamaAttention`` projections with learnable ``LlamaRMSNorm`` on queries and keys.

    The norms are applied to the full projection output, before it is split
    into heads.  ``_get_module`` returns only the wrapped attention, so
    ``q_norm``/``k_norm`` are not touched by ``copy_attention_weights``.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention and the two norm modules.

        ``add_features_dim`` and the ``context_extension_*`` arguments are
        accepted but not used by this class.
        """
        super().__init__(config, layer_idx)
        self.attention = LlamaAttention(config, layer_idx)

        head_dim = self.config.head_dim
        norm_eps = config.rms_norm_eps
        self.q_norm = LlamaRMSNorm(config.num_attention_heads * head_dim, norm_eps)
        self.k_norm = LlamaRMSNorm(config.num_key_value_heads * head_dim, norm_eps)

    def _get_module(self):
        """Return the wrapped attention module (without the extra norms)."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute attention with normalised query and key projections.

        Args:
            hidden_states: Input whose last dimension is projected to q/k/v.
            position_embeddings: ``(cos, sin)`` pair, forwarded to the cache
                update only.
            attention_mask: Passed unchanged to the attention function.
            past_key_value: Optional cache object exposing ``update``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)`` after the output projection.
        """
        attn = self.attention
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn.head_dim)

        def _to_heads(projected):
            # Split the last dimension into heads and swap dims 1 and 2.
            return projected.view(hidden_shape).transpose(1, 2)

        query_states = _to_heads(self.q_norm(attn.q_proj(hidden_states)))
        key_states = _to_heads(self.k_norm(attn.k_proj(hidden_states)))
        value_states = _to_heads(attn.v_proj(hidden_states))

        cos, sin = position_embeddings
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = attn.config._attn_implementation
        attention_interface: Callable = (
            llama_eager_attention_forward
            if implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return attn.o_proj(merged), attn_weights


class LlamaKNormNoPEAttentionSequenceMixing(SequenceMixingLayer):
    """``LlamaAttention`` projections with weight-free RMS norm on the keys only.

    ``forward`` re-implements the attention call using the wrapped module's
    projections; the position embeddings are only handed to the cache update.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention.

        ``add_features_dim`` and the ``context_extension_*`` arguments are
        accepted but not used by this class.
        """
        super().__init__(config, layer_idx)
        self.attention = LlamaAttention(config, layer_idx)

    def _get_module(self):
        """Return the wrapped attention module."""
        return self.attention

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim ** -0.5 * softmax_scale``."""
        default_scale = self.attention.head_dim**-0.5
        self.attention.scaling = default_scale * softmax_scale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute attention with keys RMS-normalised over the last dimension.

        Args:
            hidden_states: Input whose last dimension is projected to q/k/v.
            position_embeddings: ``(cos, sin)`` pair, forwarded to the cache
                update only.
            attention_mask: Passed unchanged to the attention function.
            past_key_value: Optional cache object exposing ``update``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)`` after the output projection.
        """
        attn = self.attention
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attn.head_dim)

        def _split_heads(projection):
            # Project, split the last dimension into heads, swap dims 1 and 2.
            return projection(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = _split_heads(attn.q_proj)
        raw_keys = _split_heads(attn.k_proj)
        value_states = _split_heads(attn.v_proj)

        key_states = F.rms_norm(
            raw_keys, (raw_keys.shape[-1],), eps=self.config.rms_norm_eps
        )

        cos, sin = position_embeddings
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = attn.config._attn_implementation
        attention_interface: Callable = (
            llama_eager_attention_forward
            if implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return attn.o_proj(merged), attn_weights


class LlamaKLearnableRMSNormAttentionSequenceMixing(SequenceMixingLayer):
    """Llama attention mixer that passes the key projection through an RMSNorm.

    Wraps a ``LlamaAttention`` module and adds a ``LlamaRMSNorm`` (``k_norm``)
    of width ``num_key_value_heads * head_dim``. The norm is applied to the
    flat ``k_proj`` output, before it is split into heads. ``forward`` does not
    rotate queries or keys with the position embeddings; ``cos``/``sin`` are
    only handed to the KV cache.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention module and the key norm.

        Args:
            config: Model config; must expose ``num_key_value_heads``,
                ``head_dim`` and ``rms_norm_eps`` in addition to whatever
                ``LlamaAttention`` and the base class read.
            layer_idx: Layer index, forwarded to ``LlamaAttention`` and the
                base class.
            add_features_dim: Accepted but not used by this class.
            context_extension_attn_impl: Accepted but not used by this class.
            context_extension_attn_params: Accepted but not used by this class.
            context_extension_nope_params: Accepted but not used by this class.
        """
        # Only config and layer_idx reach the base class, so it sees the
        # context-extension arguments at their defaults (None).
        super().__init__(config, layer_idx)
        self.attention = LlamaAttention(config, layer_idx)

        key_width = config.num_key_value_heads * self.config.head_dim
        self.k_norm = LlamaRMSNorm(key_width, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention with RMS-normalized keys and no rotary rotation.

        Args:
            hidden_states: Input activations; presumably
                ``(batch, seq, hidden)`` - TODO confirm against the caller.
            position_embeddings: ``(cos, sin)`` pair. Only forwarded to the
                cache update; never applied to queries or keys here.
            attention_mask: Passed through to the attention function.
            past_key_value: Optional cache object exposing ``update``; when
                given, keys/values are replaced by what ``update`` returns.
            cache_position: Forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has been
            passed through ``o_proj`` and ``attn_weights`` is whatever the
            selected attention function returned.
        """
        attn = self.attention
        lead_shape = hidden_states.shape[:-1]
        per_head_shape = (*lead_shape, -1, attn.head_dim)

        query_states = attn.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        # The norm sees the flat projection, i.e. all KV heads jointly.
        normed_keys = self.k_norm(attn.k_proj(hidden_states))
        key_states = normed_keys.view(per_head_shape).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        impl_name = attn.config._attn_implementation
        if impl_name == "eager":
            attention_interface: Callable = llama_eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl_name]

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(*lead_shape, -1).contiguous()
        return attn.o_proj(merged_heads), attn_weights

    def _get_module(self):
        """Return the wrapped ``LlamaAttention`` module."""
        attention_module = self.attention
        return attention_module

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim**-0.5 * softmax_scale``."""
        attention_module = self.attention
        base_scale = attention_module.head_dim**-0.5
        attention_module.scaling = base_scale * softmax_scale


class LlamaQLearnableRMSNormAttentionSequenceMixing(SequenceMixingLayer):
    """Llama attention mixer that passes the query projection through an RMSNorm.

    Wraps a ``LlamaAttention`` module and adds a ``LlamaRMSNorm`` (``q_norm``)
    of width ``num_attention_heads * head_dim``. The norm is applied to the
    flat ``q_proj`` output, before it is split into heads. Keys and values are
    left un-normalized, and ``forward`` does not rotate queries or keys with
    the position embeddings; ``cos``/``sin`` are only handed to the KV cache.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        add_features_dim: Optional[int] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the wrapped attention module and the query norm.

        Args:
            config: Model config; must expose ``num_attention_heads``,
                ``head_dim`` and ``rms_norm_eps`` in addition to whatever
                ``LlamaAttention`` and the base class read.
            layer_idx: Layer index, forwarded to ``LlamaAttention`` and the
                base class.
            add_features_dim: Accepted but not used by this class.
            context_extension_attn_impl: Accepted but not used by this class.
            context_extension_attn_params: Accepted but not used by this class.
            context_extension_nope_params: Accepted but not used by this class.
        """
        # Only config and layer_idx reach the base class, so it sees the
        # context-extension arguments at their defaults (None).
        super().__init__(config, layer_idx)
        self.attention = LlamaAttention(config, layer_idx)

        query_width = config.num_attention_heads * self.config.head_dim
        self.q_norm = LlamaRMSNorm(query_width, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention with RMS-normalized queries and no rotary rotation.

        Args:
            hidden_states: Input activations; presumably
                ``(batch, seq, hidden)`` - TODO confirm against the caller.
            position_embeddings: ``(cos, sin)`` pair. Only forwarded to the
                cache update; never applied to queries or keys here.
            attention_mask: Passed through to the attention function.
            past_key_value: Optional cache object exposing ``update``; when
                given, keys/values are replaced by what ``update`` returns.
            cache_position: Forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention function.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has been
            passed through ``o_proj`` and ``attn_weights`` is whatever the
            selected attention function returned.
        """
        attn = self.attention
        lead_shape = hidden_states.shape[:-1]
        per_head_shape = (*lead_shape, -1, attn.head_dim)

        # The norm sees the flat projection, i.e. all query heads jointly.
        normed_queries = self.q_norm(attn.q_proj(hidden_states))
        query_states = normed_queries.view(per_head_shape).transpose(1, 2)
        key_states = attn.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                attn.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        impl_name = attn.config._attn_implementation
        if impl_name == "eager":
            attention_interface: Callable = llama_eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl_name]

        # Dropout is only active while the wrapped module is in training mode.
        dropout_p = attn.attention_dropout if attn.training else 0.0
        attn_output, attn_weights = attention_interface(
            attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=attn.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(*lead_shape, -1).contiguous()
        return attn.o_proj(merged_heads), attn_weights

    def _get_module(self):
        """Return the wrapped ``LlamaAttention`` module."""
        attention_module = self.attention
        return attention_module

    def set_softmax_scale(self, softmax_scale: float):
        """Set the attention scaling to ``head_dim**-0.5 * softmax_scale``."""
        attention_module = self.attention
        base_scale = attention_module.head_dim**-0.5
        attention_module.scaling = base_scale * softmax_scale


# Maps the string identifier of a sequence-mixing variant to the class that
# implements it. `create_sequence_mixing_layer` looks the class up here by its
# `mixing_type` argument and lists these keys in its "unknown type" error.
SEQUENCE_MIXING_REGISTRY = {
    "qwen3_attention": Qwen3AttentionSequenceMixing,
    "qwen3_nope_attention": Qwen3NoPEAttentionSequenceMixing,
    "qwen2_nope_attention": Qwen2NoPEAttentionSequenceMixing,
    "qwen2_conv_nope_attention": Qwen2ConvNoPEAttentionSequenceMixing,
    "qwen2_high_rope_attention": Qwen2HighRoPEAttentionSequenceMixing,
    "qwen2_swan_attention": Qwen2SWANSequenceMixing,
    "qwen2_cosine_attention": Qwen2CosineNoPEAttentionSequenceMixing,
    "qwen2_alibi_attention": Qwen2AliBiAttentionSequenceMixing,
    "llama_nope_attention": LlamaNoPEAttentionSequenceMixing,
    "llama_qknorm_attention": LlamaQKNormNoPEAttentionSequenceMixing,
    "llama_knorm_attention": LlamaKNormNoPEAttentionSequenceMixing,
    "llama_qk_learnable_rmsnorm_attention": LlamaQKLearnableRMSNormAttentionSequenceMixing,
    "llama_k_learnable_rmsnorm_attention": LlamaKLearnableRMSNormAttentionSequenceMixing,
    "llama_q_learnable_rmsnorm_attention": LlamaQLearnableRMSNormAttentionSequenceMixing,
}


def create_sequence_mixing_layer(
    config,
    mixing_type: str,
    layer_idx: int,
    add_features_dim: Optional[int] = None,
    **kwargs,
):
    """Instantiate the sequence-mixing layer registered under ``mixing_type``.

    Args:
        config: Model config, passed to the layer constructor as ``config``.
        mixing_type: Key into ``SEQUENCE_MIXING_REGISTRY``.
        layer_idx: Layer index, passed to the layer constructor.
        add_features_dim: Deprecated; must be left as ``None``.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        A new instance of the class registered for ``mixing_type``.

    Raises:
        ValueError: If ``add_features_dim`` is given, or if ``mixing_type`` is
            not a key of ``SEQUENCE_MIXING_REGISTRY``.
    """
    if add_features_dim is not None:
        raise ValueError(
            "Additional feature injection was deprecated for higher throughput."
        )

    # Validate with an explicit raise rather than `assert`: assertions are
    # removed under `python -O`, which would turn an unknown type into a bare
    # KeyError with no hint about the valid choices.
    try:
        layer_cls = SEQUENCE_MIXING_REGISTRY[mixing_type]
    except KeyError:
        raise ValueError(
            f"Unknown mixing type: {mixing_type}. "
            f"Available types: {list(SEQUENCE_MIXING_REGISTRY.keys())}"
        ) from None

    return layer_cls(config=config, layer_idx=layer_idx, **kwargs)
