# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


def elu_p1(x):
    """Return ``elu(x) + 1`` cast back to the dtype and device of ``x``."""
    shifted = F.elu(x, alpha=1.0, inplace=False) + 1.0
    return shifted.to(x)


def sum_norm(x):
    """Divide ``x`` by its sum over the last axis, keeping the dtype and device of ``x``."""
    total = x.sum(dim=-1, keepdim=True)
    normalized = x / total
    return normalized.to(x)


class DeltaStack(nn.Module):
    r"""
    The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484).  # noqa
    DeltaStack was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa

    Args:
        mode (str, Optional):
            Which DeltaStack kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            `fused_chunk` is deprecated and raises `NotImplementedError`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `False`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        allow_neg_eigval (bool, Optional):
            Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
            See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        qk_activation (str, Optional):
            The activation function for the query and key. Default: `silu`.
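        qk_norm (str, Optional):
            The normalization method for the query and key. Default: `l2`.
        stack_size (int, Optional):
            The number of slots in the soft stack addressed by the pointer. Default: 64.
        d_model (int, Optional):
            Alias for `hidden_size`; takes precedence over it when given. Default: None.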
    """

    def __init__(
        self,
        mode: str = "chunk",
        d_model: int = None,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_beta: bool = True,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        allow_neg_eigval: bool = False,
        layer_idx: int = None,
        qk_activation: str = "silu",
        qk_norm: str = "l2",
        norm_eps: float = 1e-5,
        stack_size: int = 64,
        **kwargs,
    ) -> DeltaStack:
        super().__init__()

        self.stack_size = stack_size
        # self.stack_kappa = nn.Parameter(torch.tensor(20.0))
        self.stack_kappa = nn.Parameter(torch.tensor([20.0]))  # for fully shard

        self.mode = mode
        self.qk_activation = qk_activation
        self.qk_norm = qk_norm

        assert self.qk_activation in ["silu", "relu", "elu", "identity"]
        assert self.qk_norm in ["l2", "sum"]

        if d_model is not None:
            hidden_size = d_model
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.allow_neg_eigval = allow_neg_eigval

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.layer_idx = layer_idx

        if mode == "fused_chunk":
            raise NotImplementedError(
                "fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead."
            )
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
        assert (
            self.key_dim % num_heads == 0
        ), f"key dim must be divisible by num_heads of {num_heads}"
        assert (
            self.value_dim % num_heads == 0
        ), f"value dim must be divisible by num_heads of {num_heads}"

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.action_proj = nn.Linear(hidden_size, self.num_heads * 3, bias=False)

        self.use_beta = use_beta
        if self.use_beta:
            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu" if qk_activation == "silu" else None,
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu" if qk_activation == "silu" else None,
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
        else:
            warnings.warn(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing.",
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps, dtype=torch.float32)

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def _generate_soft_pointer(
        self,
        actions: torch.Tensor,
        prev_ptr: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generates stack pointers using a Linear Laplace Kernel.
        """
        # 1. Decode Actions
        action_probs = F.softmax(actions.float(), dim=-1)
        push_prob = action_probs[..., 0]
        pop_prob = action_probs[..., 1]

        delta_ptr = push_prob - pop_prob

        # 2. Integrate (Get RAW Positions)
        if cu_seqlens is not None:
            # FLATTENED MODE
            delta_flat = delta_ptr.squeeze(0)
            # pointer_change_flat = torch.cumsum(delta_flat, dim=0)

            pointer_change_list = []
            prev_ptr_list = []

            if prev_ptr is None:
                prev_ptr = torch.zeros(
                    (len(cu_seqlens) - 1, self.num_heads),
                    dtype=torch.float32,
                    device=actions.device,
                )
            else:
                prev_ptr = prev_ptr.float()

            for i in range(len(cu_seqlens) - 1):
                start, end = cu_seqlens[i], cu_seqlens[i + 1]
                seg_delta = delta_flat[start:end]
                seg_cumsum = torch.cumsum(seg_delta, dim=0)
                pointer_change_list.append(seg_cumsum)

                length = end - start
                p = prev_ptr[i].unsqueeze(0).expand(length, -1)
                prev_ptr_list.append(p)

            pointer_change = torch.cat(pointer_change_list, dim=0).unsqueeze(0)
            prev_pos_flat = torch.cat(prev_ptr_list, dim=0).unsqueeze(0)

            current_pos = prev_pos_flat + pointer_change

            # CACHE LOGIC: Extract UNCLAMPED last states (Preserves Ghost Depth)
            last_indices = cu_seqlens[1:] - 1
            new_ptr_state_raw = current_pos[0, last_indices]

        else:
            # BATCH MODE
            pointer_change = torch.cumsum(delta_ptr, dim=1)

            if prev_ptr is None:
                prev_ptr = torch.zeros(
                    (actions.size(0), actions.size(2)),
                    dtype=torch.float32,
                    device=actions.device,
                )
            else:
                prev_ptr = prev_ptr.float()

            prev_pos = prev_ptr.unsqueeze(1)
            current_pos = prev_pos + pointer_change

            # CACHE LOGIC: Extract UNCLAMPED last state
            new_ptr_state_raw = current_pos[:, -1]

        # 3. Clamping for Kernel (Physical Memory Access)
        # Prevents "teleporting" from 0 to S-1
        current_pos_clamped = torch.clamp(current_pos, 0, self.stack_size - 1.0)

        if prev_ptr is not None:
            prev_ptr_clamped = torch.clamp(prev_ptr, 0, self.stack_size - 1.0)
        else:
            prev_ptr_clamped = torch.zeros_like(current_pos_clamped)

        # 4. Laplace Kernel (Scale-Invariant, Linear Topology)
        # Grid: [1, 1, 1, S]
        grid_idx = torch.arange(
            self.stack_size, device=actions.device, dtype=torch.float32
        ).view(1, 1, 1, -1)

        # Learnable Sharpness
        sharpness = F.softplus(self.stack_kappa)

        def compute_laplace(ptr):
            # ptr: [..., H] -> [..., H, 1]
            p = ptr.unsqueeze(-1)

            # L1 Distance (Linear Topology)
            dist = torch.abs(grid_idx - p)

            # Logits = -sharpness * distance
            logits = -sharpness * dist

            return F.softmax(logits, dim=-1).to(dtype=actions.dtype)

        # Generate K_STACK using Clamped Current Position
        k_stack = compute_laplace(current_pos_clamped)

        # Generate K_PREV using Clamped Previous Pointer
        if cu_seqlens is not None:
            k_prev_global = compute_laplace(prev_ptr_clamped.unsqueeze(1))
        else:
            k_prev_global = compute_laplace(prev_ptr_clamped.unsqueeze(1))

        return (
            k_stack,
            push_prob.to(actions.dtype),
            pop_prob.to(actions.dtype),
            new_ptr_state_raw,  # Return Raw for Cache
            k_prev_global,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs: Unpack[dict],
    ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.shape
        # change to inference mode.
        mode = "fused_recurrent" if q_len <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        recurrent_state, stack_recurrent_state, stack_ptr_state = None, None, None
        prev_stack_out = None
        if last_state is not None:
            recurrent_state, stack_recurrent_state, stack_ptr_state, prev_stack_out = (
                last_state["recurrent_state"]
            )

        cu_seqlens = kwargs.get("cu_seqlens")
        indices = None

        if attention_mask is not None:
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(
                rearrange(hidden_states, "b s ... -> (b s) ..."), indices
            ).unsqueeze(0)

        # 4. Pointer Generation
        # hidden_states is [1, Total, D] (if unpadded) or [B, T, D]
        action_logits = self.action_proj(hidden_states)
        actions = rearrange(
            action_logits, "b t (h d) -> b t h d", h=self.num_heads, d=3
        )

        k_stack, push_prob, pop_prob, new_ptr_state, k_prev_global = (
            self._generate_soft_pointer(
                actions, prev_ptr=stack_ptr_state, cu_seqlens=cu_seqlens
            )
        )

        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            if self.qk_activation == "silu":
                q, k = F.silu(q), F.silu(k)
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(
            lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k)
        )
        v = rearrange(v, "... (h d) -> ... h d", d=self.head_v_dim)
        if self.qk_activation != "silu":
            if self.qk_activation == "relu":
                q, k = q.relu(), k.relu()
            elif self.qk_activation == "elu":
                q, k = elu_p1(q), elu_p1(k)
            elif self.qk_activation != "identity":
                raise NotImplementedError

        if self.qk_norm == "sum":
            q = sum_norm(q).to(q)
            k = sum_norm(k).to(k)

        if self.use_beta:
            beta = self.b_proj(hidden_states).sigmoid()
        else:
            beta = torch.ones_like(q[..., 0])

        if self.allow_neg_eigval:
            beta = beta * 2.0

        # 6. Interleaving
        k_stack = k_stack.to(dtype=v.dtype)

        # Construct k_pop
        if cu_seqlens is not None:
            # Flattened Mode: Shift and patch boundaries
            k_shifted = torch.roll(k_stack, shifts=1, dims=1)
            start_indices = cu_seqlens[:-1]
            k_prev_flat = k_prev_global.squeeze(1)  # [B, H, S]
            k_shifted[0, start_indices] = k_prev_flat
            k_pop = k_shifted
        else:
            # Batch Mode
            if k_stack.shape[1] > 1:
                k_pop = torch.cat([k_prev_global, k_stack[:, :-1]], dim=1)
            else:
                k_pop = k_prev_global

        v_pop = torch.zeros_like(v)
        beta_pop = pop_prob.unsqueeze(-1)
        k_push = k_stack
        v_push = v
        beta_push = push_prob.unsqueeze(-1)

        def interleave(a, b):
            combined = torch.stack([a, b], dim=2)
            return combined.flatten(1, 2)

        # Force Contiguous!
        k_interleaved = interleave(k_pop, k_push).contiguous()
        v_interleaved = interleave(v_pop, v_push).contiguous()
        beta_interleaved = interleave(beta_pop, beta_push).squeeze(-1).contiguous()

        cu_seqlens_2x = None
        if cu_seqlens is not None:
            cu_seqlens_2x = cu_seqlens * 2

        if mode == "chunk":
            stack_out_2x, new_stack_recurrent_state = chunk_delta_rule(
                q=k_interleaved,
                k=k_interleaved,
                v=v_interleaved,
                beta=beta_interleaved,
                initial_state=stack_recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens_2x,
                use_qk_l2norm_in_kernel=True,
            )
        elif mode == "fused_recurrent":
            stack_out_2x, new_stack_recurrent_state = fused_recurrent_delta_rule(
                q=k_interleaved,
                k=k_interleaved,
                v=v_interleaved,
                beta=beta_interleaved,
                initial_state=stack_recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens_2x,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        stack_out = stack_out_2x[:, 1::2].contiguous()

        # === 8. TIME-SHIFT LOGIC ===
        stack_read = None

        # CASE 1: Generation / Inference (Single Step)
        if q_len == 1:
            # If we have history in cache, use it
            if prev_stack_out is not None:
                if indices is not None:
                    # Unpadded Generation: We must gather the specific active batches from the cache.
                    # prev_stack_out is [B, 1, H, D]. indices map to active batches.

                    # 1. Flatten cache to [B, H, D]
                    cache_flat = prev_stack_out.squeeze(1)

                    # 2. Gather active batches: [Active_Count, H, D]
                    # 'indices' contains the batch indices that are currently active (mask=1)
                    active_cache = cache_flat[indices]

                    # 3. Reshape to [1, Active_Count, H, D] to match 'stack_out' shape
                    stack_read = active_cache.unsqueeze(0)
                else:
                    # Standard Batch Generation: Direct copy
                    stack_read = prev_stack_out
            else:
                # First token of generation (or use_cache=False): Zero init
                stack_read = torch.zeros_like(stack_out)

        # CASE 2: Training / Prefill with Variable Length (Flattened)
        elif cu_seqlens is not None:
            # We are in a long sequence (T > 1). We use the Roll logic.
            # Shift everything right by 1
            stack_read = torch.roll(stack_out, shifts=1, dims=1)

            # Zero out the start of each document to prevent bleeding
            start_indices = cu_seqlens[:-1]
            stack_read[:, start_indices] = 0

        # CASE 3: Standard Training / Prefill (Batch Mode)
        else:
            # Shift and zero-init the first time step
            zeros = torch.zeros_like(stack_out[:, :1])
            stack_read = torch.cat([zeros, stack_out[:, :-1]], dim=1)

        v_aug = v - stack_read  # minimize square(Sk + stack_v - orig_v)

        # Determine value to cache for next step
        # If flattened, we need the last token of each segment?
        # Actually, for Inference (Recurrent), inputs are never flattened (batch size B, len 1).
        # Flattening only happens in Training.
        # So for caching, we only care about the Batch Mode logic.
        # next_prev_stack_out = stack_out

        if mode == "fused_recurrent":
            o, recurrent_state = fused_recurrent_delta_rule(
                q=q,
                k=k,
                v=v_aug,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=(self.qk_norm == "l2"),
            )
        elif mode == "chunk":
            o, recurrent_state = chunk_delta_rule(
                q=q,
                k=k,
                v=v_aug,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=(self.qk_norm == "l2"),
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        next_prev_stack_out = None
        if use_cache:
            if cu_seqlens is not None:
                # FLATTENED MODE:
                # stack_out is [1, Total, H, D].
                # We need to extract the last token of EACH document.
                # The end index of doc[i] is cu_seqlens[i+1] - 1.

                # cu_seqlens is [0, len1, len1+len2, ...]
                # last_indices are [len1-1, len1+len2-1, ...]
                last_indices = cu_seqlens[1:] - 1

                # Gather: [1, Total, H, D] -> [B, H, D]
                # Note: stack_out is 1 in dim 0, so we use index 0.
                next_prev_stack_out = stack_out[0, last_indices]

                # Reshape to [B, 1, H, D] for consistency with inference
                next_prev_stack_out = next_prev_stack_out.unsqueeze(1)
            else:
                # BATCH MODE:
                # stack_out is [B, T, H, D]. Just take the last step.
                next_prev_stack_out = stack_out[:, -1:]

        if past_key_values is not None:
            # OPTIMIZATION: Only store the LAST step's output.
            # In inference (T=1), this is the same as stack_out.
            # In prefill (T=2048), this discards the history we don't need for the next step.
            past_key_values.update(
                recurrent_state=(
                    recurrent_state,
                    new_stack_recurrent_state,
                    new_ptr_state,
                    next_prev_stack_out,
                ),
                conv_state=(
                    (conv_state_q, conv_state_k, conv_state_v)
                    if self.use_short_conv
                    else None
                ),
                layer_idx=self.layer_idx,
                offset=q_len,
            )

        if self.use_gate:
            g = rearrange(
                self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim
            )
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)
        if attention_mask is not None:
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)

        return o, None, past_key_values

    def set_stack_size(self, new_size: int):
        """Updates the virtual stack size for evaluation."""
        self.orig_stack_size = self.stack_size
        self.stack_size = new_size

    def reset_stack_size(self):
        """Updates the virtual stack size for evaluation."""
        self.stack_size = self.orig_stack_size
