# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch.nn import functional as F

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import (
    chunk_gated_delta_rule,
    fused_recurrent_gated_delta_rule,
)
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule


if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
    """Return ``elu(x) + 1`` cast back to the dtype/device of ``x``.

    The shift by one moves the ELU output range from ``(-1, inf)`` to
    ``(0, inf)``, so the result is strictly positive.
    """
    shifted = F.elu(x, alpha=1.0, inplace=False) + 1.0
    return shifted.to(x)


@torch.compile
def sum_norm(x):
    """Divide ``x`` by its sum over the last dimension.

    The result is cast back to the dtype/device of ``x``. No guard is applied
    against a zero sum.
    """
    total = x.sum(dim=-1, keepdim=True)
    return torch.div(x, total).to(x)


class GatedDeltaStack(nn.Module):
    """
    A [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464) layer  # noqa
    augmented with a differentiable stack memory.

    On top of the usual gated delta rule recurrence, every value head owns a soft stack pointer.
    `action_proj` emits three logits per value head and token; after a softmax the first probability
    ("push") moves the pointer forward, the second ("pop") moves it backward and the remaining mass
    leaves it unchanged. The continuous pointer is turned into a soft address over `stack_size` slots,
    which is used as the key of a second (ungated) delta rule memory. What is read back from that
    memory is added to the values fed to the main gated delta rule kernel.

    Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.

    Parameter allocation when use_gate=True:
        - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
        - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
        - Others are ignorably small.
        - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
    NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.

    Parameter allocation when use_gate=False:
        - 1 * hidden_size * hidden_size for the q_proj and k_proj each
        - 2 * hidden_size * hidden_size for the v_proj and o_proj each
        - Others are ignorably small.
        - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_v_heads (int, Optional):
            The number of heads for the value projection, equal to `num_heads` if `None`.
            GVA is applied if `num_v_heads` > `num_heads`. Default: `None`.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        allow_neg_eigval (bool, Optional):
            Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
            See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
        stack_size (int, Optional):
            The number of slots of the soft stack. The pointer lives on a ring of this size. Default: 64.
        **kwargs:
            Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 4,
        num_v_heads: int | None = None,
        mode: str = "chunk",
        use_gate: bool = True,
        use_short_conv: bool = True,
        allow_neg_eigval: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int | None = None,
        norm_eps: float = 1e-5,
        stack_size: int = 64,
        **kwargs,
    ) -> None:
        super().__init__()

        # The soft address divides by `stack_size`; reject sizes that would only fail later.
        if stack_size < 1:
            raise ValueError(f"stack_size={stack_size} must be a positive integer.")

        self.stack_size = stack_size
        # Learnable sharpness of the soft address (passed through softplus before use).
        self.stack_kappa = nn.Parameter(torch.tensor(10.0))

        self.mode = mode
        self.allow_neg_eigval = allow_neg_eigval
        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads

        self.head_k_dim = head_dim
        self.head_v_dim = int(self.head_dim * self.expand_v)
        self.key_dim = int(self.num_heads * self.head_k_dim)
        self.value_dim = int(self.num_v_heads * self.head_v_dim)
        self.layer_idx = layer_idx

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(
            self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5
        ):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by "
                f"num_v_heads * head_dim={self.num_v_heads * self.head_dim}. "
                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear.",
            )
        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
            raise ValueError(
                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
            )

        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated.",
            )
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."

        # NOTE: module/parameter creation order is kept stable so seeded initialisation is reproducible.
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
        # Three stack-action logits (push / pop / stay) per value head.
        self.action_proj = nn.Linear(hidden_size, self.num_v_heads * 3, bias=False)

        A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min),
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
        else:
            warnings.warn(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing.",
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps, dtype=torch.float32)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def _soft_address(self, pos: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        """
        Convert continuous pointer positions into soft addresses over the stack slots.

        Slots are placed on a ring of `stack_size` positions; the address is a softmax over
        `-2 * kappa * sin^2(angle_diff / 2)`, i.e. a periodic bump centred on the pointer.

        Args:
            pos: float tensor of pointer positions, shape `[..., H]` with three dims in total
                (e.g. `[B, T, H]`).
            out_dtype: dtype of the returned address tensor.

        Returns:
            Tensor of shape `pos.shape + (stack_size,)` whose last dim sums to one.
        """
        scale = 2 * torch.pi / self.stack_size
        ptr_angle = pos.unsqueeze(-1) * scale
        grid_idx = torch.arange(
            self.stack_size, device=pos.device, dtype=torch.float32
        ).view(1, 1, 1, -1)
        grid_angle = grid_idx * scale

        # sin(theta/2)^2 is used instead of (cos(theta) - 1):
        # cos(x) - 1 = -2 * sin(x/2)^2, without cancellation for small x.
        angle_diff = ptr_angle - grid_angle

        # Learnable sharpness "per index step": multiplying by S^2 / (2*pi^2) undoes the
        # index -> angle conversion, so logits ~= -base_kappa * index_diff^2 near the pointer.
        base_kappa = F.softplus(self.stack_kappa).float()
        scaling_factor = (self.stack_size ** 2) / (2 * (torch.pi ** 2))
        kappa_effective = base_kappa * scaling_factor

        # Upper bound on the sharpness to keep the logits in a numerically safe range.
        kappa_effective = torch.clamp(kappa_effective, max=1e4)

        logits = -2.0 * kappa_effective * torch.sin(angle_diff / 2)**2
        return F.softmax(logits, dim=-1).to(dtype=out_dtype)

    def _generate_soft_pointer(
        self,
        actions: torch.Tensor,
        prev_ptr: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Integrate push/pop actions into a pointer trajectory and its soft stack addresses.

        Args:
            actions: `[1, Total_Tokens, H, 3]` (flattened, with `cu_seqlens`) or `[B, T, H, 3]` (batch).
                The last dim holds the push / pop / stay logits.
            prev_ptr: `[B, H]` pointer carried over from a previous call, or `None` to start at 0.
            cu_seqlens: `[B+1]` tensor of cumulative sequence lengths. When given, the pointer is
                integrated independently inside each segment of the flattened input.

        Returns:
            A tuple `(k_stack, push_prob, pop_prob, last_ptr, k_prev_global)`:
                - `k_stack`: soft address after each token, `actions.shape[:3] + (stack_size,)`.
                - `push_prob`, `pop_prob`: action probabilities in `actions.dtype`, shape `actions.shape[:3]`.
                - `last_ptr`: float32 `[B, H]` pointer after the last token of every sequence
                  (equal to `prev_ptr` for an empty segment).
                - `k_prev_global`: `[B, 1, H, stack_size]` soft address of `prev_ptr`, i.e. the
                  address "before" the first token of every sequence.

        Raises:
            ValueError: if `cu_seqlens` is given but `actions` does not have a batch dim of 1.
        """
        # 1. Action decoding in float32.
        action_probs = F.softmax(actions.float(), dim=-1)
        push_prob = action_probs[..., 0]
        pop_prob = action_probs[..., 1]
        delta_ptr = push_prob - pop_prob

        # Always size the pointer from the actions themselves so both layouts agree.
        num_ptr_heads = actions.size(2)

        # 2. Integration of the pointer movement.
        if cu_seqlens is not None:
            if actions.size(0) != 1:
                raise ValueError(
                    f"Expected a flattened batch of size 1 when `cu_seqlens` is given, got {actions.size(0)}.",
                )
            delta_flat = delta_ptr.squeeze(0)  # [Total, H]
            # One host transfer for all boundaries instead of one sync per slice.
            bounds = cu_seqlens.tolist()

            if prev_ptr is None:
                prev_ptr = torch.zeros(
                    (len(bounds) - 1, num_ptr_heads),
                    dtype=torch.float32,
                    device=actions.device,
                )
            else:
                prev_ptr = prev_ptr.float()

            # B is usually small, so a per-segment cumsum is simple and avoids the
            # precision loss of "global cumsum minus offset" on long packed inputs.
            seg_positions = []
            seg_last = []
            for seq_idx, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
                seg_pos = prev_ptr[seq_idx] + torch.cumsum(delta_flat[start:end], dim=0)
                seg_positions.append(seg_pos)
                # An empty segment leaves the pointer where it was.
                seg_last.append(seg_pos[-1] if end > start else prev_ptr[seq_idx])

            current_pos = torch.cat(seg_positions, dim=0).unsqueeze(0)  # [1, Total, H]
            last_ptr = torch.stack(seg_last, dim=0)  # [B, H]
        else:
            if prev_ptr is None:
                prev_ptr = torch.zeros(
                    (actions.size(0), num_ptr_heads),
                    dtype=torch.float32,
                    device=actions.device,
                )
            else:
                prev_ptr = prev_ptr.float()

            current_pos = prev_ptr.unsqueeze(1) + torch.cumsum(delta_ptr, dim=1)  # [B, T, H]
            last_ptr = current_pos[:, -1]

        # 3. Soft addresses. The address of `prev_ptr` is returned separately: the caller needs it
        # as the "previous" address of the first token of every sequence.
        k_stack = self._soft_address(current_pos, actions.dtype)
        k_prev_global = self._soft_address(prev_ptr.unsqueeze(1), actions.dtype)  # [B, 1, H, S]

        return (
            k_stack,
            push_prob.to(actions.dtype),
            pop_prob.to(actions.dtype),
            last_ptr,
            k_prev_global,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs: Unpack[dict],
    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
        """
        Run the stack-augmented gated delta rule layer.

        Args:
            hidden_states: input of shape `[B, T, hidden_size]`.
            attention_mask: optional 2-D padding mask; when given, padded tokens are removed
                and the batch is processed as one flattened sequence with `cu_seqlens`.
            past_key_values: optional cache; entry `layer_idx` is read for the previous state and
                updated with the new one.
            use_cache: whether the kernels should return their final states.
            output_attentions: unused.
            **kwargs: may carry `cu_seqlens` for inputs that are already flattened.

        Returns:
            `(output, None, past_key_values)` where `output` has shape `[B, T, hidden_size]`.
        """
        # 1. Setup
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 2-D matrix with shape [batch_size, seq_len]."
            )

        batch_size, q_len, _ = hidden_states.shape
        # Short sequences at inference time always use the recurrent kernels.
        mode = "fused_recurrent" if (q_len <= 64 and not self.training) else self.mode
        if mode == "chunk":
            stack_kernel, main_kernel = chunk_delta_rule, chunk_gated_delta_rule
        elif mode == "fused_recurrent":
            stack_kernel, main_kernel = (
                fused_recurrent_delta_rule,
                fused_recurrent_gated_delta_rule,
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # 2. History retrieval
        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        recurrent_state, stack_recurrent_state, stack_ptr_state = None, None, None
        if last_state is not None:
            recurrent_state, stack_recurrent_state, stack_ptr_state = last_state[
                "recurrent_state"
            ]

        # 3. Unpadding (flattening): [B, T, D] -> [1, Total_Valid, D]
        cu_seqlens = kwargs.get("cu_seqlens")
        indices = None
        if attention_mask is not None:
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(
                rearrange(hidden_states, "b s ... -> (b s) ..."), indices
            ).unsqueeze(0)

        # 4. Pointer generation.
        # `action_proj` emits 3 logits per *value* head, and the stack keys must have as many
        # heads as `v`, so the split uses `num_v_heads` (not `num_heads`).
        actions = rearrange(
            self.action_proj(hidden_states),
            "b t (h d) -> b t h d",
            h=self.num_v_heads,
            d=3,
        )
        k_stack, push_prob, pop_prob, new_ptr_state, k_prev_global = (
            self._generate_soft_pointer(
                actions, prev_ptr=stack_ptr_state, cu_seqlens=cu_seqlens
            )
        )

        # 5. Projections
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(
            lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)
        )
        v = rearrange(v, "... (h d) -> ... h d", d=self.head_v_dim)

        if self.num_v_heads > self.num_heads:
            # GVA: share each q/k head across a group of value heads.
            q, k = map(
                lambda x: repeat(
                    x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads
                ),
                (q, k),
            )

        beta = self.b_proj(hidden_states).sigmoid()
        if self.allow_neg_eigval:
            beta = beta * 2.0

        g = -self.A_log.float().exp() * F.softplus(
            self.a_proj(hidden_states).float() + self.dt_bias
        )

        # 6. Interleaving: every token becomes a (pop, push) pair of stack writes.
        k_stack = k_stack.to(dtype=v.dtype)
        # Keep the boundary address in the same dtype as `k_stack` so it can be written into it.
        k_prev = k_prev_global.to(dtype=v.dtype)

        # The pop step addresses the slot the pointer was at *before* the current token.
        if cu_seqlens is not None:
            # Flattened mode: shift by one token, then patch the first token of every sequence
            # with the address of its carried-over pointer.
            k_pop = torch.roll(k_stack, shifts=1, dims=1)
            seq_starts = cu_seqlens[:-1].long()
            # Empty sequences have no first token to patch.
            non_empty = cu_seqlens[1:] > cu_seqlens[:-1]
            k_pop[0, seq_starts[non_empty]] = k_prev.squeeze(1)[non_empty]
        else:
            # Batch mode (also covers T == 1, where the slice is empty).
            k_pop = torch.cat([k_prev, k_stack[:, :-1]], dim=1)

        def interleave(first, second):
            # [B, T, ...] x 2 -> [B, 2T, ...] ordered first_0, second_0, first_1, second_1, ...
            return torch.stack([first, second], dim=2).flatten(1, 2)

        # Pop writes a zero value with strength `pop_prob`; push writes `v` with strength `push_prob`.
        k_interleaved = interleave(k_pop, k_stack).contiguous()
        v_interleaved = interleave(torch.zeros_like(v), v).contiguous()
        beta_interleaved = interleave(pop_prob, push_prob).contiguous()

        # The interleaved sequences are twice as long, so the boundaries double as well.
        cu_seqlens_2x = cu_seqlens * 2 if cu_seqlens is not None else None

        # 7. Kernels
        stack_out_2x, new_stack_recurrent_state = stack_kernel(
            q=k_interleaved,
            k=k_interleaved,
            v=v_interleaved,
            beta=beta_interleaved,
            initial_state=stack_recurrent_state,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens_2x,
            use_qk_l2norm_in_kernel=True,
        )

        # Keep the read-out at the push positions (odd indices) and add it to the values.
        stack_out = stack_out_2x[:, 1::2].contiguous()
        v_aug = v + stack_out

        o, recurrent_state = main_kernel(
            q=q,
            k=k,
            v=v_aug,
            g=g,
            beta=beta,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=True,
        )

        # 8. Cache update
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=(
                    recurrent_state,
                    new_stack_recurrent_state,
                    new_ptr_state,
                ),
                conv_state=(
                    (conv_state_q, conv_state_k, conv_state_v)
                    if self.use_short_conv
                    else None
                ),
                layer_idx=self.layer_idx,
                offset=q_len,
            )

        # 9. Output norm & projection
        if self.use_gate:
            g_out = rearrange(
                self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim
            )
            o = self.o_norm(o, g_out)
        else:
            o = self.o_norm(o)

        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)

        # 10. Repad / unflatten
        if attention_mask is not None:
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)

        return o, None, past_key_values
