"""
Gaussian State Space Model (Gauss SSM) - Kalman Filter for Sequence Modeling

This is the main Gauss SSM implementation for the gauss-mad framework.
A feature-complete Bayesian state space model with:
- Log-space associative scan for numerically stable Kalman filtering
- Discretization via Ornstein-Uhlenbeck (OU) method
- MIMO (multi-input-multi-output) expansion via rank projections
- Token-dependent observation matrices (selective_h)

REQUIREMENTS:
- PyTorch >= 2.8 (for torch._higher_order_ops.associative_scan)

USAGE:
    # Standalone module (copy this file to your project)
    from gauss import GaussBlock

    layer = GaussBlock(dim=128, d_state=16)
    output = layer(x)  # x: (B, L, dim) -> output: (B, L, dim)

    # With uncertainty quantification
    layer = GaussBlock(dim=128, d_state=16, return_variance=True)
    output, variance = layer(x)  # extras accessible via layer._last_extras

KEY ARCHITECTURE CONCEPT - The q_projected Contraction:
    The SSM internally operates in EXPANDED space (d_model * d_state dimensions)
    but must return to d_model dimensions for residual connections.
    This is done via learned "q_projected" weights - see GaussBlock.forward() for details.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
import math

# Import PyTorch's higher order ops for associative_scan (required)
try:
    from torch import _higher_order_ops as hops
    TORCH_2_8_AVAILABLE = True
except ImportError:
    TORCH_2_8_AVAILABLE = False
    raise ImportError(
        "torch._higher_order_ops (PyTorch 2.8+) is required for GaussBlock. "
        "Please upgrade PyTorch to version 2.8 or higher."
    )

# FLA import for L2 norm (with fallback for standalone usage)
try:
    from fla.modules.l2norm import l2_norm as l2_norm_fn
except ImportError:
    try:
        from fla.modules.l2norm import l2_norm_fn
    except ImportError:
        # Fallback: simple L2 normalization when FLA not available
        def l2_norm_fn(x):
            return F.normalize(x, p=2, dim=-1)


class GaussSSM(nn.Module):
    """
    Core Gaussian State Space Model with embedded TorchLogScan filtering.

    Features:
    - Discretization (OU method, inline)
    - A parameterization (log-space for stability, inline)
    - Process noise (always exists, discretized if needed)
    - Feature expansion (outer product, inline)
    - Token-dependent h and delta support
    """
    def __init__(
        self,
        dim,
        d_state,
        discretize=True,
        a_parameterization='log_space',
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        process_noise_scale=0.01,
        selective_h=True,
        mimo_rank=1,
        use_b_bias=False,
        clip_value=None,              # Clip process_noise to max value
        *args, **kwargs
    ):
        """
        Create the learnable parameters of the Gaussian SSM.

        Args:
            dim: Model width; stored as ``self.d_model``.
            d_state: State expansion factor per model channel.
            discretize: If True, a ``delta`` parameter of size
                ``dim * d_state`` is created for OU discretization.
            a_parameterization: ``'log_space'`` creates ``lambda_log`` of size
                ``dim * d_state``; any other value creates a direct ``A``
                parameter of size ``dim``.
            dt_min, dt_max: Bounds of the log-uniform range used to
                initialise ``delta``.
            dt_init_floor: Lower clamp applied to the sampled time steps
                before they are mapped through the inverse softplus.
            process_noise_scale: Initial process noise when
                ``discretize=False``.
            selective_h: If True, no static ``h`` parameter is created
                (``self.h`` is None); otherwise ``h`` has shape
                ``(d_state, mimo_rank)``.
            mimo_rank: Rank of the MIMO projections.
            use_b_bias: If True, creates ``b_bias`` of shape
                ``(dim, d_state, mimo_rank)``, initialised to zero.
            clip_value: Stored on the module as ``self.clip_value``.
            *args: Ignored.
            **kwargs: Only ``gauss_visualization`` is read (default False).
        """
        super().__init__()
        self.d_model = dim
        self.d_state = d_state
        self.discretize = discretize
        self.a_parameterization = a_parameterization
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.gauss_visualization = kwargs.get('gauss_visualization', False)

        # Device tracking buffer for proper multi-GPU support
        self.register_buffer('_device_indicator', torch.zeros(1))

        # A parameterization (inline - no factory)
        if a_parameterization == 'log_space':
            # Log-space: A = -exp(lambda_log), so A is always negative
            lambda_init = torch.ones(self.d_model*self.d_state)
            self.lambda_log = nn.Parameter(torch.log(lambda_init))
            self.lambda_log._no_weight_decay = True
        else:
            # Direct parameterization (one value per model channel)
            self.A = nn.Parameter(torch.ones(self.d_model))
            self.A._no_weight_decay = True

        # h: observation matrix (base parameter, can be overridden by selective_h)
        if selective_h:
            self.h = None  # Handled externally
        else:
            self.h = nn.Parameter(torch.ones(self.d_state,self.mimo_rank))

        # Process noise: ALWAYS exists (fundamental to Kalman filter)
        if discretize:
            # NOTE: the discretized branch starts from a fixed 0.01 and does
            # not read process_noise_scale.
            self.process_noise = nn.Parameter(torch.ones(self.d_model) * 0.01)
        else:
            # Continuous-time: simple scaling
            self.process_noise = nn.Parameter(torch.ones(self.d_model) * process_noise_scale)

        # B bias: learnable bias combined with h inside ssm()
        self.use_b_bias = use_b_bias
        if use_b_bias:
            # Shape: (d_model, d_state, mimo_rank)
            self.b_bias = nn.Parameter(torch.zeros(self.d_model, self.d_state, self.mimo_rank))

        # Delta: only needed for discretization
        # Initialize at full expanded dimension (d_model * d_state)
        if discretize:
            # Sample log(dt) uniformly in [log(dt_min), log(dt_max)]
            dt = torch.rand(self.d_model * self.d_state) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
            # Store the pre-softplus value so softplus(delta) recovers dt
            inv_dt = self._inverse_softplus(torch.exp(dt).clamp(min=dt_init_floor))
            self.delta = nn.Parameter(inv_dt * 1.0)  # Scale factor = 1.0

        # No runtime availability check is needed here: the module-level
        # import raises ImportError when torch._higher_order_ops is missing,
        # so this class cannot be constructed without it.

    @property
    def device(self):
        """Return the device of the ``_device_indicator`` buffer."""
        indicator = self._device_indicator
        return indicator.device

    def _init_mean_cov(self, batch_size: int):
        """Build the initial state prior on the module's device.

        Returns:
            (mean, cov): a zero tensor and a ones tensor, each of shape
            (batch_size, d_state).
        """
        prior_shape = (batch_size, self.d_state)
        target = self.device
        zero_mean = torch.zeros(*prior_shape, device=target)
        unit_cov = torch.ones(*prior_shape, device=target)
        return zero_mean, unit_cov

    def _compute_A(self):
        """Return the continuous-time A for the configured parameterization.

        For 'log_space' this is ``-exp(lambda_log)``; for any other setting
        the raw ``A`` parameter is returned unchanged.
        """
        if self.a_parameterization != 'log_space':
            return self.A
        return -torch.exp(self.lambda_log)

    def _discretize_ou(self, A, delta):
        """
        Discretize A using the OU (Ornstein-Uhlenbeck) method.
        Inline implementation (no strategy factory).

        Continuous → Discrete: a = exp(-delta * |A|)

        The magnitude of A is used as the decay rate, matching
        ``_discretize_process_noise_ou``. For negative A (the log-space
        parameterization) this equals exp(delta * A); for a positive A it
        still yields a factor in (0, 1] instead of a growing one.

        Args:
            A: Continuous-time rate; a 1-D tensor is given two leading
                singleton axes when ``delta`` has more dimensions.
            delta: Positive step size, broadcastable against A.

        Returns:
            Discrete transition factor with the broadcast shape of A and delta.
        """
        if A.ndim < delta.ndim:
            A = A[None, None, :]  # (d,) → (1, 1, d) → broadcasts
        decay_rate = A.abs()
        return torch.exp(-delta * decay_rate)

    @staticmethod
    def _inverse_softplus(dt: torch.Tensor) -> torch.Tensor:
        """
        Compute inverse of softplus: x such that softplus(x) = dt.

        Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759

        Args:
            dt: Delta values (after softplus)

        Returns:
            inv_dt: Values to be used as parameters (before softplus)
        """
        return dt + torch.log(-torch.expm1(-dt))

    def _discretize_process_noise_ou(self, A, q_continuous, delta):
        """
        Discretize the process noise with the exact OU formula.

        Q_discrete = Q_continuous * (1 - exp(-2*delta*|A|)) / (2*|A|)

        A and q_continuous get two leading singleton axes when they have
        fewer dimensions than delta, so they broadcast against it. |A| is
        clamped to at least 1e-8 before dividing.
        """
        target_rank = delta.ndim
        if A.ndim < target_rank:
            A = A[None, None, :]  # (d,) → (1, 1, d)
        if q_continuous.ndim < target_rank:
            q_continuous = q_continuous[None, None, :]  # (d,) → (1, 1, d)

        rate = A.abs().clamp(min=1e-8)
        scaled_noise = q_continuous / (2.0 * rate)
        # 1 - exp(-2*rate*delta), computed through expm1
        growth = -torch.expm1(-2.0 * rate * delta)
        return scaled_noise * growth

    def _expand_outer_product(self, input, exp_factor):
        """
        Repeat every feature `exp_factor` times along the last axis.

        Each element is copied into `exp_factor` consecutive slots, so a
        trailing axis of size d becomes d * exp_factor.

        Args:
            input: Tensor of shape (d_model,) or (B, L, d_model)
            exp_factor: Number of copies per feature (d_state)

        Returns:
            Tensor of shape (d_model*exp_factor,) or (B, L, d_model*exp_factor)
        """
        if input.ndim == 1:
            # Parameter vector
            width = input.shape[0]
            tiled = input.unsqueeze(-1).expand(-1, exp_factor)
            return tiled.reshape(width * exp_factor)

        batch, seq_len, width = input.shape
        tiled = input.unsqueeze(-1).expand(-1, -1, -1, exp_factor)
        return tiled.reshape(batch, seq_len, width * exp_factor)

    def ssm(
        self,
        mu: torch.Tensor,
        var: torch.Tensor,
        h: torch.Tensor,
        a: torch.Tensor,
        q: torch.Tensor,
        prior: bool = False,
        return_extras: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Pure SSM filtering: Kalman filter on (μ, σ²) using log-space associative scan.
        Supports MIMO expansion via matrix multiplication through the rank dimension.

        The posterior precision is obtained from one associative scan over
        log-domain 2×2 coefficient blocks; the information vector is obtained
        from a second, real-space scan. The posterior mean is their ratio.

        Args:
            mu: Observation mean (B, L, d_model, r) - matrix form with rank
            var: Observation variance (B, L, d_model, r) - matrix form with rank
            h: Observation matrix (B, L, d_state, r) - matrix form with rank
            a: State transition, (d_model*d_state,) or (B, L, d_model*d_state);
                a 1-D tensor is broadcast over batch and time
            q: Process noise, (d_model*d_state,) or (B, L, d_model*d_state);
                a 1-D tensor is broadcast over batch and time
            prior: If True, return (a * post_mean, a^2 * post_variance)
                instead of the posterior
            return_extras: Kept for API compatibility; currently ignored
                (extras are always populated)

        Returns:
            post_mean: Posterior mean (B, L, d_model*d_state)
            post_variance: Posterior variance (B, L, d_model*d_state)
            extras: Dictionary of detached tensors, always populated
                - prior_mean: (B, L, dim) - a * post_mean
                - prior_variance: (B, L, dim) - a^2 * post_variance
                - observation_influence: (B, L, dim) - phi = h^2/var
                - precision: (B, L, dim) - posterior precision (1/variance)
                - information_vector: (B, L, dim) - eta
                - a_effective: (B, L, dim) - the `a` used by the filter
                - q_effective: (B, L, dim) - the `q` used by the filter

        Raises:
            AssertionError: If mu, var or h is not 4-dimensional.
        """
        # Assert matrix forms with rank dimension
        assert mu.ndim == 4, f"mu must have 4 dims (B, L, d_model, r), got {mu.shape}"
        assert var.ndim == 4, f"var must have 4 dims (B, L, d_model, r), got {var.shape}"
        assert h.ndim == 4, f"h must have 4 dims (B, L, d_state, r), got {h.shape}"

        B, L, d_model, r = mu.shape
        _, _, d_state, _ = h.shape
        device = mu.device

        # Numerical safety: floor applied before every log / division below
        EPS = 1e-12

        # Broadcast 1-D a, q to (B, L, d_model*d_state); already-3-D inputs pass through
        dim = d_model * d_state
        if a.ndim == 1:
            a = a[None, None, :].expand(B, L, -1)
        if q.ndim == 1:
            q = q[None, None, :].expand(B, L, -1)

        a2 = a.pow(2)

        # MIMO expansion: φ = (1/var) @ (h^2)^T
        # Shape: (B, L, d_model, r) @ (B, L, r, d_state) → (B, L, d_model, d_state)
        h_squared = h.pow(2)  # (B, L, d_state, r)
        inv_var = 1.0 / var  # (B, L, d_model, r)
        if self.use_b_bias:
            # b_bias makes the observation matrix depend on the model channel
            h_eff = h.unsqueeze(2) + self.b_bias.unsqueeze(0).unsqueeze(0)     # (B, L, d_model, d_state, r)
            h_eff_sq = h_eff.square().permute(0, 1, 2, 4, 3).contiguous()      # (B, L, d_model, r, d_state)

            phi = torch.matmul(
                inv_var.unsqueeze(-2),   # (B, L, d_model, 1, r)
                h_eff_sq                 # (B, L, d_model, r, d_state)
            ).squeeze(-2)                # (B, L, d_model, d_state)
            phi = phi.reshape(B, L, d_model * d_state).clamp(min=EPS)

        else:
            phi = torch.matmul(inv_var, h_squared.transpose(-2, -1))  # (B, L, d_model, d_state)
            phi = phi.reshape(B, L, d_model * d_state).clamp(min=EPS)
        # Log-domain Möbius coefficients of the precision recurrence
        #   λ_t = (A·λ_{t-1} + B) / (C·λ_{t-1} + D)
        # with A = (1 + q·φ)/a², B = φ, C = q/a², D = 1
        logA = ((1/a2) * (1 + q * phi)).clamp(min=EPS).log()
        logB = phi.clamp(min=EPS).log()
        logC = ((1/a2) * q).clamp(min=EPS).log()
        logD = torch.ones(B, L, dim, device=device).clamp(min=EPS).log()  # log(1) = 0

        # Pack into (B, L, 4, dim) for 2×2 block matrix representation
        logM = torch.stack([logA, logB, logC, logD], dim=2)

        # Define log-domain matrix composition for associative scan
        def log_matrix_compose(logX, logY):
            """
            Compose two log-domain 2×2 block matrices: log(Y @ X)

            Each entry of the product is a sum of two products, evaluated as
            logaddexp of two sums of logs.

            logX, logY: (B, 4, dim)
            Returns: (B, 4, dim)
            """
            A1, B1, C1, D1 = logX.unbind(1)
            A2, B2, C2, D2 = logY.unbind(1)

            LA = torch.logaddexp(A2 + A1, B2 + C1)
            LB = torch.logaddexp(A2 + B1, B2 + D1)
            LC = torch.logaddexp(C2 + A1, D2 + C1)
            LD = torch.logaddexp(C2 + B1, D2 + D1)

            return torch.stack([LA, LB, LC, LD], dim=1)

        # Run associative scan for precision
        logM_prefix = hops.associative_scan(
            combine_fn=log_matrix_compose,
            xs=logM,
            dim=1,  # Scan along time dimension
            combine_mode="generic"
        )

        # Unpack prefix scan results
        logA_p, logB_p, logC_p, logD_p = logM_prefix.unbind(2)

        # Initial prior: zero mean, unit variance, shape (B, d_state)
        prior_mean0, prior_cov0 = self._init_mean_cov(B)

        # Handle dimension mismatch if needed (when dim > d_state due to expansion)
        if prior_mean0.shape[-1] != dim:
            # Repeat each prior entry so the last axis has `dim` entries
            expansion_ratio = dim // prior_mean0.shape[-1]
            prior_mean0 = prior_mean0.repeat_interleave(expansion_ratio, dim=-1)
            prior_cov0 = prior_cov0.repeat_interleave(expansion_ratio, dim=-1)

        log_lam0 = -prior_cov0.log()  # log(λ₀) where λ₀ = 1/cov₀ (B, dim)

        # Apply the accumulated Möbius map to λ₀, staying in log-space
        log_num = torch.logaddexp(logA_p + log_lam0[:, None, :], logB_p)
        log_den = torch.logaddexp(logC_p + log_lam0[:, None, :], logD_p)

        log_post_prec = log_num - log_den
        post_prec = log_post_prec.exp()
        post_variance = (-log_post_prec).exp()

        # Posterior mean via information vector scan (scan-based, no FFT)
        def info_compose(X, Y):
            """Compose info vectors in real space: Y ∘ X

            The recurrence is: η_t = (1/a) * η_{t-1} + r_t
            In matrix form: [a_inv, r] where result is [a_inv_prod, eta]

            Note: r can be negative (since mu can be negative), so we use real-space
            instead of log-space to preserve the sign.
            """
            a1, r1 = X.unbind(1)
            a2, r2 = Y.unbind(1)
            # a_comp = a2 * a1 (product of 1/a terms)
            # r_comp = a2 * r1 + r2 (accumulated weighted sum)
            return torch.stack([a2 * a1, a2 * r1 + r2], dim=1)

        # NOTE(review): the 1/a coefficient below does not involve q; it looks
        # like the exact information-form update only when q == 0 — confirm
        # this approximation is intended.
        # Build information vector inputs (real-space, not log-space)
        a_inv = 1.0 / a.clamp(min=EPS)  # (B, L, dim), positive after the clamp

        # MIMO expansion: r = (mu/var) @ h^T
        # Shape: (B, L, d_model, r) @ (B, L, r, d_state) → (B, L, d_model, d_state)
        # Note: r can be NEGATIVE since mu can be negative
        mu_over_var = mu / var  # (B, L, d_model, r)
        if self.use_b_bias:
            # Same biased observation matrix as used for phi above
            h_eff = h.unsqueeze(2) + self.b_bias.unsqueeze(0).unsqueeze(0)      # (B, L, d_model, d_state, r)
            h_eff_t = h_eff.permute(0, 1, 2, 4, 3).contiguous()                 # (B, L, d_model, r, d_state)

            r = torch.matmul(
                mu_over_var.unsqueeze(-2),   # (B, L, d_model, 1, r)
                h_eff_t                      # (B, L, d_model, r, d_state)
            ).squeeze(-2)
        else:
            r = torch.matmul(mu_over_var, h.transpose(-2, -1))  # (B, L, d_model, d_state)
        r = r.reshape(B, L, d_model * d_state)  # (B, L, d_model*d_state), can be negative

        # Pack into (B, L, 2, dim) for scan - real space, not log space
        M_eta = torch.stack([a_inv, r], dim=2)

        # Run information vector scan for mean (real-space)
        M_eta_prefix = hops.associative_scan(
            combine_fn=info_compose,
            xs=M_eta,
            dim=1,  # Scan along time dimension
            combine_mode="generic"
        )

        # Extract information vector η (already in real space, can be negative)
        eta = M_eta_prefix[:, :, 1, :]

        # Posterior mean: μ = η / λ
        post_mean = eta / post_prec

        # One-step propagation of the posterior through a; q is not added here
        next_prior_mean = a * post_mean
        next_prior_variance = a2 * post_variance

        # Extras are always populated from values that are already computed.
        # Note: return_extras parameter kept for API compatibility but now ignored
        extras = {
            'prior_mean': next_prior_mean.detach(),
            'prior_variance': next_prior_variance.detach(),
            'observation_influence': phi.detach(),
            'precision': post_prec.detach(),
            'information_vector': eta.detach(),
            'a_effective': a.detach(),
            'q_effective': q.detach(),
        }

        if prior:
            return next_prior_mean, next_prior_variance, extras
        else:
            return post_mean, post_variance, extras

    def forward(
        self,
        latent_token: torch.Tensor,
        token_variance: torch.Tensor,
        h_projected: Optional[torch.Tensor] = None,
        delta_bias: Optional[torch.Tensor] = None
    ):
        """
        Forward pass with optional discretization and expansion.

        Args:
            latent_token: Latent tokens of shape (B, L, d_model*r)
            token_variance: Token variance of shape (B, L, d_model*r)
            h_projected: Optional token-dependent h (B, L, d_state*r).
                Required when the module has no static ``h`` parameter.
            delta_bias: Optional token-dependent delta bias. A last axis of
                size 1 or d_model*d_state is added to delta by broadcasting;
                any other width (e.g. d_model) is first expanded by
                d_state via outer product.

        Returns:
            tuple: (post_mean, post_variance, extras)

        Raises:
            ValueError: If ``h_projected`` is None and no static ``h`` exists.
        """
        B, L, _ = latent_token.shape
        expanded_dim = self.d_model * self.d_state

        # Step 1: Compute A
        A = self._compute_A()
        if A.shape[-1] == self.d_model and self.d_model != expanded_dim:
            # The direct parameterization stores one value per model channel;
            # bring it to the expanded (d_model*d_state,) layout used below.
            A = self._expand_outer_product(A, self.d_state)

        # Step 2: Process noise (ALWAYS exists)
        q_base = self.process_noise.clamp(min=1e-7)  # (d_model,)
        if self.clip_value is not None:
            q_base = q_base.clamp(max=self.clip_value)

        # Step 3: Reshape inputs to matrix form for MIMO
        # (B, L, d_model*r) → (B, L, d_model, r); (B, L, d_state*r) → (B, L, d_state, r)
        latent_token = latent_token.reshape(B, L, self.d_model, self.mimo_rank)
        token_variance = token_variance.reshape(B, L, self.d_model, self.mimo_rank)

        if h_projected is not None:
            h = h_projected.reshape(B, L, self.d_state, self.mimo_rank)
        elif self.h is not None:
            # Static h case
            h = self.h.reshape(1, 1, self.d_state, self.mimo_rank).expand(B, L, -1, -1)
        else:
            raise ValueError(
                "h_projected is required when GaussSSM has no static h "
                "(selective_h=True)"
            )

        # Step 4: Expand q_base via outer product (it does not participate in MIMO)
        q_base = self._expand_outer_product(q_base, self.d_state)  # (d_model*d_state)

        # Step 5: Discretize if needed
        if self.discretize:
            # Delta is already at full expanded dimension (d_model*d_state)
            delta = self.delta

            # Add token-dependent bias if provided (Mamba-style)
            if delta_bias is not None:
                if delta_bias.shape[-1] not in (1, expanded_dim):
                    # Per-channel bias: repeat each channel d_state times
                    delta_bias = self._expand_outer_product(delta_bias, self.d_state)
                # A width-1 bias broadcasts over every expanded channel
                delta = delta[None, None, :] + delta_bias  # (B, L, d_model*d_state)
            else:
                delta = delta[None, None, :].expand(B, L, -1)  # (B, L, d_model*d_state)

            delta = F.softplus(delta) + 1e-7  # Ensure positivity

            # Discretize A using OU method (inline)
            a = self._discretize_ou(A, delta)  # (B, L, d_model*d_state)

            # Discretize process noise using OU method (inline)
            q = self._discretize_process_noise_ou(A, q_base, delta)  # (B, L, d_model*d_state)
        else:
            # Continuous-time: use A and q directly
            a = A  # (d_model*d_state,)
            q = q_base  # (d_model*d_state,)

        # Step 6: Run Kalman filter (log-scan)
        post_mean, post_variance, extras = self.ssm(
            mu=latent_token,
            var=token_variance,
            h=h,
            a=a,
            q=q,
            return_extras=self.gauss_visualization
        )

        # Add SSM parameters to extras for visualization
        extras['a_effective'] = a.detach() if isinstance(a, torch.Tensor) else a
        extras['q_effective'] = q.detach() if isinstance(q, torch.Tensor) else q

        return post_mean, post_variance, extras


class GaussBlock(nn.Module):
    """
    Gaussian state space model block compatible with MAD framework.
    Feature-complete version with inline implementations for easy understanding.

    Compatible with mad-lab's LanguageModel:
    - Takes input (B, L, dim)
    - Returns output (B, L, dim) - same shape for residual connection
    - No mixing logic needed (handled by LanguageModel)
    """
    def __init__(
        self,
        dim,                          # Required by mad-lab
        d_state=16,                   # State dimension
        discretize=True,              # Enable discretization
        a_parameterization='log_space', # A parameterization
        selective_h=True,            # Token-dependent h (observation matrix)
        selective_process=False,      # Token-dependent process noise (future use)
        selective_delta=False,        # Token-dependent delta (Mamba-style)
        h_feature_map_mode=None,      # Feature map mode for h expansion
        gating_mode='multiplicative', # 'additive' or 'multiplicative'
        mlp_hidden_units=None,        # MLP before SSM
        use_causal_conv=False,        # Causal conv1d
        conv_kernel_size=3,
        conv_activation='silu',
        conv_groups=None,
        use_gating=False,             # Enable gating
        gating_activation='silu',
        use_fused_norm_gate=False,    # Use FusedRMSNormSwishGate (store flag, not implemented)
        use_lambda_skip=False,        # Lambda skip connection
        lambda_skip_mode='scalar',    # 'scalar' or 'vector' (per-channel like Mamba D)
        lambda_init=-1.0,             # Initialization value for lambda_param (default -1.0)
        process_noise_scale=0.01,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        bias=False,
        return_variance=False,        # Return (mean, var) tuple if True
        mimo_rank=1,                  # Rank for h/q projections
        use_qk_norm=False,            # QK normalization for H and Q
        use_c_bias=False,             # Learnable bias for C (q_projected), init=1
        use_b_bias=False,             # Learnable bias for B (h_projected), init=1
        compile_ssm=True,             # Compile SSM with torch.compile
        clip_value=None,              # Clip h, q, variance, process_noise to max value
        *args, **kwargs               # Catches max_length, etc.
    ):
        """
        Assemble the block: input projection, optional causal conv, sensor
        projection, optional biases / gating / lambda skip, and the core
        GaussSSM (optionally wrapped by torch.compile).

        ``h_feature_map_mode`` and ``mlp_hidden_units`` are accepted but not
        used by this constructor; ``selective_process`` and
        ``use_fused_norm_gate`` are only stored. From ``kwargs`` only
        ``gauss_visualization`` is read.
        """
        super().__init__()

        # --- Plain configuration -------------------------------------------
        self.d_model = dim
        self.d_state = dim if d_state is None else d_state
        self.return_variance = return_variance
        self.gating_mode = gating_mode
        self.selective_h = selective_h
        self.selective_process = selective_process  # kept for future use
        self.selective_delta = selective_delta
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.use_fused_norm_gate = use_fused_norm_gate  # flag only, no implementation here

        # Buffer whose device is reported by the `device` property
        self.register_buffer('_device_indicator', torch.zeros(1))

        # Slots for the most recent variance / extras (visualization)
        self._last_variance = None
        self._last_extras = None

        # Forwarded to GaussSSM below
        self.gauss_visualization = kwargs.get('gauss_visualization', False)

        # --- 1. Input projection (doubled width when a gate is split off) ---
        in_proj_width = self.d_model * 2 if use_gating else self.d_model
        self.in_proj = nn.Linear(dim, in_proj_width, bias=bias)

        # --- 2. MLP slot: always None, argument kept for compatibility -------
        self.mlp = None

        # --- 3. Optional causal conv1d ---------------------------------------
        self.use_causal_conv = use_causal_conv
        if use_causal_conv:
            n_groups = self.d_model if conv_groups is None else conv_groups
            self.conv1d = nn.Conv1d(
                in_channels=self.d_model,
                out_channels=self.d_model,
                kernel_size=conv_kernel_size,
                groups=n_groups,
                padding=conv_kernel_size - 1,
                bias=True
            )
            self.conv_activation = conv_activation

        # --- 4. Sensor projection --------------------------------------------
        # Output is later split into [latent_token, variance, h, q] with widths
        # r*d_model, r*d_model, r*d_state and d_state (r = mimo_rank).
        latent_and_var_width = 2 * self.mimo_rank * self.d_model
        h_width = self.mimo_rank * self.d_state
        q_width = self.d_state
        sensor_output_dim = latent_and_var_width + h_width + q_width
        self.sensor_proj = nn.Linear(self.d_model, sensor_output_dim, bias=bias)

        # 4b. QK norm flag (the norm function itself has no parameters)
        self.use_qk_norm = use_qk_norm

        # 4c. Learnable C bias, initialised to ones; the B bias lives in GaussSSM
        self.use_c_bias = use_c_bias
        if use_c_bias:
            self.c_bias = nn.Parameter(torch.ones(self.d_state))

        # --- 5. Selective delta projection (only meaningful when discretizing)
        if selective_delta and discretize:
            self.delta_proj = nn.Linear(self.d_model, 1, bias=False)
        else:
            self.delta_proj = None

        # --- 6. Gating activation --------------------------------------------
        self.use_gating = use_gating
        if use_gating:
            if gating_activation == 'gelu':
                act_cls = nn.GELU
            elif gating_activation == 'relu':
                act_cls = nn.ReLU
            else:
                # 'silu' and any unrecognised name both use SiLU
                act_cls = nn.SiLU
            self.gate_act = act_cls()

        # --- 7. Lambda skip connection ----------------------------------------
        self.lambda_skip_mode = lambda_skip_mode
        if not use_lambda_skip:
            self.lambda_param = None
        elif lambda_skip_mode == 'vector':
            # One lambda per channel
            self.lambda_param = nn.Parameter(torch.full((self.d_model,), lambda_init))
        else:
            # Single shared scalar
            self.lambda_param = nn.Parameter(torch.tensor(lambda_init))

        # --- 8. Core SSM -------------------------------------------------------
        core = GaussSSM(
            dim=self.d_model,
            d_state=self.d_state,
            discretize=discretize,
            a_parameterization=a_parameterization,
            dt_min=dt_min,
            dt_max=dt_max,
            dt_init_floor=dt_init_floor,
            process_noise_scale=process_noise_scale,
            selective_h=selective_h,
            mimo_rank=self.mimo_rank,
            use_b_bias=use_b_bias,
            clip_value=self.clip_value,
            gauss_visualization=self.gauss_visualization
        )

        # Best-effort compilation: on failure keep the eager module and warn
        if compile_ssm:
            try:
                core = torch.compile(core)
            except Exception as e:
                print(f"Warning: torch.compile failed for GaussSSM: {e}")
        self.ssm = core


    @property
    def device(self):
        """Return the device of the ``_device_indicator`` buffer."""
        indicator = self._device_indicator
        return indicator.device

    def forward(self, x, *args, **kwargs):
        """
        Forward pass compatible with mad-lab framework.

        Args:
            x: (B, L, dim)

        Returns:
            output: (B, L, dim) - if return_variance=False (default, mad-lab compatible)
            (output, variance): tuple - if return_variance=True (for uncertainty)
        """
        B, L, D = x.shape

        # 1. Input projection
        projected = self.in_proj(x)
        if self.use_gating:
            z, gate = projected.chunk(2, dim=-1)  # Each: (B, L, d_model)
        else:
            z = projected  # (B, L, d_model)
            gate = None

        # 2. MLP removed for alignment with gauss_core (self.mlp is always None)

        # 3. Optional causal conv1d
        if self.use_causal_conv:
            z_conv = z.transpose(-1, -2)  # (B, d_model, L)
            z_conv = self.conv1d(z_conv)
            # Remove extra padding to make it causal
            z_conv = z_conv[..., :L]
            # Apply configured activation
            if self.conv_activation == 'silu':
                z_conv = F.silu(z_conv)
            elif self.conv_activation == 'relu':
                z_conv = F.relu(z_conv)
            elif self.conv_activation == 'gelu':
                z_conv = F.gelu(z_conv)
            else:
                z_conv = F.silu(z_conv)  # default
            z = z_conv.transpose(-1, -2)  # back to (B, L, d_model)

        # Store residual for lambda skip
        residual = z

        # 4. Sensor projection with simple 4-way split
        sensor_projected = self.sensor_proj(z)  # (B, L, 2*mimo_rank*(d_model+d_state))

        # Split into: [latent_token, variance, h, q]
        # Each component has dimension: mimo_rank * component_dim
        latent_token, log_var, h_projected, q_projected = torch.split(
            sensor_projected,
            [self.mimo_rank * self.d_model,
             self.mimo_rank * self.d_model,
             self.mimo_rank * self.d_state,
             self.d_state],
            dim=-1
        )

        # Apply L2 QK norm if enabled (like GatedDeltaNet)
        # For mimo_rank > 1, normalize per-rank (not jointly across all ranks)
        if self.use_qk_norm:
            # h_projected: (B, L, mimo_rank * d_state) -> per-rank normalization
            h_proj_reshape = h_projected.view(B, L, self.mimo_rank, self.d_state)
            h_proj_reshape = l2_norm_fn(h_proj_reshape)  # normalizes last dim (d_state) per rank
            h_projected = h_proj_reshape.view(B, L, self.mimo_rank * self.d_state).to(latent_token)
            # q_projected: (B, L, d_state) - normalize directly
            q_projected = l2_norm_fn(q_projected).to(latent_token)

        # Apply clipping to h_projected and q_projected if enabled
        if self.clip_value is not None:
            h_projected = h_projected.clamp(-self.clip_value, self.clip_value)
            q_projected = q_projected.clamp(-self.clip_value, self.clip_value)

        # Apply c_bias after norm (if enabled)
        # Note: b_bias is now applied inside GaussSSM to phi and r with shape (d_model, d_state)
        if self.use_c_bias:
            q_projected = q_projected + self.c_bias

        # Ensure positive variance using softplus
        token_variance = F.softplus(log_var) + 1e-6

        # Clip token_variance if enabled (only upper bound since variance is positive)
        if self.clip_value is not None:
            token_variance = token_variance.clamp(max=self.clip_value)

        # 5. Selective delta (token-dependent time scale)
        delta_bias = None
        if self.delta_proj is not None:
            delta_bias = self.delta_proj(z)  # (B, L, 1)

        # 6. GaussSSM filtering (operates in d_model space with internal expansion)
        post_mean, post_variance, extras = self.ssm(latent_token, token_variance, h_projected, delta_bias)
        # post_mean: (B, L, d_model*d_state) where d_state is expansion factor

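        # Add observation-side quantities to extras for visualization.
        # NOTE: token_variance, h_projected and latent_token are stored exactly as split
        # above, i.e. flat (B, L, mimo_rank * component_dim); the shapes in the trailing
        # comments below are not produced by any reshape in this section.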
        extras['token_variance'] = token_variance.detach()  # (B, L, d_model, r) - observation noise R
        extras['h_projected'] = h_projected.detach()        # (B, L, d_state, r) - observation matrix h
        extras['h_effective'] = h_projected.detach()        # Alias for visualization compatibility
        extras['q_projected'] = q_projected.detach()        # (B, L, d_state) - contraction weights for d_state aggregation
        extras['latent_token'] = latent_token.detach()      # (B, L, d_model, r) - observation mean (projected x)

        # =======================================================================
        # CRITICAL: ENCODER/DECODER CONTRACTION MECHANISM
        # =======================================================================
        # The SSM internally operates in EXPANDED space: (B, L, d_model * d_state)
        # where d_state is the expansion factor (default: 16).
        #
        # To return to original dimension for residual connections, we CONTRACT
        # using learned q_projected weights:
        #
        #   post_mean:    (B, L, d_model, d_state) - Kalman posterior mean
        #   q_projected:  (B, L, d_state)          - learned combination weights
        #
        #   y = post_mean @ q_projected.unsqueeze(-1)
        #     = (B, L, d_model, d_state) @ (B, L, d_state, 1)
        #     = (B, L, d_model, 1) → squeeze → (B, L, d_model)
        #
        # This is analogous to the C matrix in standard SSMs: y = Cx
        # The weights are LEARNED and token-dependent: q_projected is computed from each
        # token by sensor_proj, giving one weight vector per (batch, position).
        #
        # For variance: combines quadratically: var_out = sum(q² * var_expanded)
        # This is the correct propagation rule for linear combinations of
        # independent random variables.
        #
        # WHY THIS MATTERS FOR STANDALONE USAGE:
        # If you copy just this module, the q_projected weights are part of
        # sensor_proj and are learned during training. They determine how the
        # d_state latent dimensions are combined into the output.
        # =======================================================================

        # 7. SSM output is (B, L, d_model*d_state); contract it to (B, L, d_model)
        post_mean = post_mean.view(B, L, self.d_model, self.d_state)  # (B, L, d_model, d_state)
        post_variance = post_variance.view(B, L, self.d_model, self.d_state)  # (B, L, d_model, d_state)
        # Combine using matmul: (B, L, d_model, d_state) @ (B, L, d_state, 1) -> (B, L, d_model, 1) -> squeeze
        y = torch.matmul(post_mean, q_projected.unsqueeze(-1)).squeeze(-1)
        y_var = torch.matmul(post_variance, q_projected.pow(2).unsqueeze(-1)).squeeze(-1)  # Variance combines quadratically

        # 8. Store expanded posterior for attention computation & contract prior_variance for ratio
        extras['post_variance_expanded'] = post_variance.detach()  # (B, L, d_model, d_state)
        if 'prior_variance' in extras:
            # Contract prior_variance to match y_var shape for ratio computation
            pv = extras['prior_variance'].view(B, L, self.d_model, self.d_state)
            extras['prior_variance'] = torch.matmul(pv, q_projected.pow(2).unsqueeze(-1)).squeeze(-1).detach()

        # 9. Reshape a_effective and q_effective for visualization
        # Handle both discretized (3D) and continuous (1D) cases
        if 'a_effective' in extras:
            a_eff = extras['a_effective']
            if a_eff.ndim == 1:
                # Continuous case: (d_model*d_state,) -> expand to (B, L, d_model, d_state)
                extras['a_effective'] = a_eff.view(1, 1, self.d_model, self.d_state).expand(B, L, -1, -1).contiguous().detach()
            else:
                # Discretized case: (B, L, d_model*d_state) -> reshape to (B, L, d_model, d_state)
                extras['a_effective'] = a_eff.view(B, L, self.d_model, self.d_state).detach()

        if 'q_effective' in extras:
            q_eff = extras['q_effective']
            if q_eff.ndim == 1:
                # Continuous case: (d_model*d_state,) -> expand to (B, L, d_model, d_state)
                extras['q_effective'] = q_eff.view(1, 1, self.d_model, self.d_state).expand(B, L, -1, -1).contiguous().detach()
            else:
                # Discretized case: (B, L, d_model*d_state) -> reshape to (B, L, d_model, d_state)
                extras['q_effective'] = q_eff.view(B, L, self.d_model, self.d_state).detach()

        # 10. Lambda skip connection
        if self.lambda_param is not None:
            y = y + self.lambda_param * residual

        # 11. Gating mechanism
        if self.use_gating and gate is not None:
            if self.gating_mode == 'additive':
                # Additive gating: add gated signal (residual-style)
                y = y + self.gate_act(gate)
            else:  # multiplicative (default)
                # Multiplicative gating: element-wise multiplication (Mamba-style)
                y = y  * self.gate_act(gate)

        # Store states for visualization (accessible even when return_variance=False)
        self._last_variance = y_var.detach()
        self._last_extras = extras

        # Return based on mad-lab compatibility flag
        if self.return_variance:
            return y, y_var  # extras accessible via self._last_extras
        else:
            return y  # Mad-lab compatible (like Mamba, Attention, etc.)
