"""
GatedDeltaNet for MAD-lab framework.
Based on NVIDIA's "Gated Delta Networks: Improving Mamba2 with Delta Rule" (ICLR '25)
Adapted for MAD-lab layer interface.
"""
from __future__ import annotations
from typing import Optional, Tuple
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from .ops.gated_delta_rule_ops import chunk_gated_delta_rule

try:
    from fla.modules.l2norm import l2_norm as l2_norm_fn
except ImportError:
    from fla.modules.l2norm import l2_norm_fn


class GatedDeltaNet(nn.Module):
    """
    Gated DeltaNet layer compatible with MAD-lab framework.

    Key features:
    - Memory gate (gk): Token-dependent exponential decay
    - Update gate (beta): Controls write strength
    - Delta rule: Subtracts stored content before writing
    - L2 normalized Q/K
    - Short convolutions on Q, K, V

    Architecture:
        Input → Q,K,V projections → ShortConv1d → L2Norm →
        GatedDeltaRule(q,k,v,beta,gk) → OutputGating → Output

    Args:
        dim: Hidden (model) dimension of the input and output.
        expand_k: Key/query width as a multiple of ``dim``.
        expand_v: Value width as a multiple of ``dim``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads. ``None`` means
            ``num_heads``. Must divide ``num_heads``; each K/V head is
            repeated ``num_heads // num_kv_heads`` times in ``forward``.
        qk_norm: ``'l2'`` L2-normalises q and k, ``'softmax'`` applies a
            softmax over the head dimension (and disables the SiLU on the
            q/k short convolutions), ``'longhorn'`` leaves q/k untouched and
            rescales ``beta`` instead. Any other value skips normalisation.
        conv_size: Kernel size of the short convolutions.
        conv_bias: Accepted for config compatibility; not used by this layer.
        gate_fn: Output gate activation (key into ``ACT2FN``).
        elementwise_affine: Passed to the output RMSNorm.
        norm_eps: Epsilon of the output RMSNorm.
        gate_logit_normalizer: Divisor applied to ``logsigmoid`` of the decay
            logits when ``use_mamba_gate`` is False.
        fuse_norm: Use ``FusedRMSNormSwishGate`` when ``gate_fn == 'swish'``.
        use_mamba_gate: Compute the decay as ``-exp(A_log) * softplus(x + dt_bias)``.
        use_residual: Add the per-head skip term ``D * v`` to the output.
        use_input_gate: Scale ``v`` by ``1 - exp(gk)`` before the recurrence.
        layer_idx: Stored on the module; not otherwise used here.
        *args, **kwargs: Ignored; absorb extra params from config.

    Raises:
        ValueError: If the head counts or the derived key/value widths are
            not mutually divisible.
    """
    def __init__(
        self,
        dim: int,                        # MAD-lab interface: hidden dimension
        expand_k: float = 0.75,          # Key expansion factor
        expand_v: float = 1.5,           # Value expansion factor
        num_heads: int = 4,              # Number of attention heads
        num_kv_heads: Optional[int] = None,
        qk_norm: str = 'l2',             # QK normalization: 'l2', 'softmax', 'longhorn'
        conv_size: int = 4,              # Short conv kernel size
        conv_bias: bool = False,
        gate_fn: str = 'swish',
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        fuse_norm: bool = True,
        use_mamba_gate: bool = True,     # Mamba-style decay gate
        use_residual: bool = False,      # Mamba2-style skip connection
        use_input_gate: bool = False,
        layer_idx: Optional[int] = None,
        *args, **kwargs                  # Absorb extra params from config
    ):
        super().__init__()

        # Validate the head layout up front: otherwise a bad config surfaces
        # as a ZeroDivisionError below or as a shape error deep in forward().
        kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        if num_heads <= 0 or kv_heads <= 0 or num_heads % kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of "
                f"num_kv_heads ({kv_heads})"
            )
        key_dim = int(dim * expand_k)
        value_dim = int(dim * expand_v)
        if key_dim % num_heads != 0 or value_dim % num_heads != 0:
            raise ValueError(
                f"key_dim ({key_dim}) and value_dim ({value_dim}) must both be "
                f"divisible by num_heads ({num_heads})"
            )

        self.d_model = dim
        self.hidden_size = dim
        self.qk_norm = qk_norm
        self.num_heads = num_heads
        self.num_kv_heads = kv_heads
        self.num_kv_groups = num_heads // kv_heads
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.conv_size = conv_size
        self.layer_idx = layer_idx

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        # Projections. K and V are projected to the per-group width so that
        # they match their short convolutions and the head repeat in forward().
        self.q_proj = nn.Linear(dim, self.key_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(dim, self.value_dim_per_group, bias=False)
        self.g_proj = nn.Linear(dim, self.value_dim, bias=False)

        # Short convolutions (no SiLU on q/k when they go through a softmax)
        qk_activation = 'silu' if qk_norm != 'softmax' else None
        self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=qk_activation)
        self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation=qk_activation)
        self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        # Gates (the Mamba-style gate carries its own bias in dt_bias)
        self.gk_proj = nn.Linear(dim, self.num_heads, bias=not use_mamba_gate)
        self.b_proj = nn.Linear(dim, self.num_heads, bias=True)

        # Output gating
        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        self.o_proj = nn.Linear(self.value_dim, dim, bias=False)
        self.gate_logit_normalizer = gate_logit_normalizer

        # Mamba-style gate parameters
        self.use_mamba_gate = use_mamba_gate
        if use_mamba_gate:
            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
            self.A_log = nn.Parameter(torch.log(A))
            self.A_log._no_weight_decay = True
            # D is registered here as well to keep the parameter set (and
            # hence state_dict keys) independent of use_residual.
            self.D = nn.Parameter(torch.ones(self.num_heads))
            self.D._no_weight_decay = True
            dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4
            # Sample dt log-uniformly in [dt_min, dt_max].
            dt = torch.exp(torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
            dt = torch.clamp(dt, min=dt_init_floor)
            # Inverse softplus, so that softplus(dt_bias) == dt at init.
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            self.dt_bias._no_weight_decay = True

        self.use_residual = use_residual
        if use_residual and not use_mamba_gate:
            # Only create D if the Mamba-gate branch has not already done so.
            self.D = nn.Parameter(torch.ones(self.num_heads))
            self.D._no_weight_decay = True
        self.use_input_gate = use_input_gate

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass compatible with MAD-lab.

        Args:
            x: Input tensor (B, L, dim)
            *args, **kwargs: Ignored.

        Returns:
            Output tensor (B, L, dim)
        """
        # Projections
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Short convolutions (FLA returns tuple: (output, final_state))
        q, _ = self.q_conv1d(q)
        k, _ = self.k_conv1d(k)
        v, _ = self.v_conv1d(v)

        # Memory gate (gk) - controls exponential decay; computed in fp32
        gk = self.gk_proj(x).float()
        if self.use_mamba_gate:
            gk = -self.A_log.float().exp() * F.softplus(gk + self.dt_bias)
        else:
            gk = F.logsigmoid(gk) / self.gate_logit_normalizer
        gk = gk.transpose(1, 2)  # (B, L, H) -> (B, H, L)

        # Update gate (beta) - controls write strength
        beta = self.b_proj(x).float().sigmoid()
        beta = beta.transpose(1, 2)  # (B, L, H) -> (B, H, L)

        # Reshape for multi-head
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
        if self.num_kv_groups > 1:
            # Grouped K/V: replicate every K/V head across its query group.
            k, v = (repeat(t, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for t in (k, v))
        else:
            k, v = (rearrange(t, 'b l (h d) -> b h l d', h=self.num_kv_heads) for t in (k, v))

        # QK normalization (any other qk_norm value leaves q/k/beta as they are)
        if self.qk_norm == 'l2':
            q = l2_norm_fn(q).to(v)
            k = l2_norm_fn(k).to(v)
        elif self.qk_norm == 'softmax':
            k = k.softmax(dim=-1).to(v)
            q = q.softmax(dim=-1).to(v)
        elif self.qk_norm == 'longhorn':
            # q/k stay unnormalised; the write strength is rescaled by ||k||^2.
            beta = beta / (1 + beta * (k * k).sum(-1))

        if self.use_input_gate:
            # gk is a log-decay, so 1 - exp(gk) is the complementary weight.
            v = (v * (1 - gk.float().exp())[..., None]).to(v.dtype)

        # Core gated delta rule
        # Use BT=32 to reduce shared memory usage (default BT=64 exceeds some GPUs' limits)
        o, _ = chunk_gated_delta_rule(q, k, v, beta, gk, BT=32, initial_state=None, output_final_state=False)

        # Residual connection (Mamba2-style)
        if self.use_residual:
            o = o + self.D[None, :, None, None] * v

        o = rearrange(o, 'b h l d -> b l h d')

        # Output gating
        g = self.g_proj(x)
        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)

        return self.o_proj(o)
