"""
Kalman Linear Attention (KLA) Block.

A sequence modeling layer that implements Kalman filtering with an associative scan,
enabling parallel computation over sequences while maintaining principled uncertainty
quantification.

Key features:
- Full process noise parameterization at d_model*d_state dimension
- Corrected information vector (eta) update using proper Kalman gain
- Log-space associative scan for numerical stability
- OU discretization for continuous-to-discrete conversion
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
import math
import warnings

try:
    from torch import _higher_order_ops as hops

    TORCH_2_8_AVAILABLE = True
except ImportError:
    TORCH_2_8_AVAILABLE = False
    raise ImportError(
        "torch._higher_order_ops (PyTorch 2.8+) is required for KLABlock. "
        "Please upgrade PyTorch to version 2.8 or higher."
    )

try:
    from fla.modules.l2norm import l2_norm as l2_norm_fn
except ImportError:
    from fla.modules.l2norm import l2_norm_fn

# from ..ops.log_mobius_scan import log_mobius_scan
# from scan import mamba_scan as linear_scan


def log_mobius_scan(logA, logB, logC, logD):
    """Causal scan of element-wise 2x2 Möbius (linear-fractional) maps in log space.

    At each step ``t`` the matrix ``M_t = [[A_t, B_t], [C_t, D_t]]`` (given by
    its element-wise logs) is left-multiplied onto the running product,
    ``P_t = M_t @ P_{t-1}`` with ``P_{-1} = I``. The product is then evaluated
    at ``x = 1``::

        x_t = (P_t[0,0] + P_t[0,1]) / (P_t[1,0] + P_t[1,1])

    This equals iterating ``x <- (A_t x + B_t) / (C_t x + D_t)`` from ``x = 1``.

    A Möbius map is unchanged when its matrix is rescaled. The running product
    is therefore renormalised after every step: its largest log-entry is
    shifted to 0. This keeps the log-entries bounded instead of letting them
    grow linearly with the sequence length, which would otherwise erode fp32
    precision in the final log-ratio. No output depends on the shift, so it is
    detached.

    Args:
        logA, logB, logC, logD: log-entries of ``M_t``, each shaped
            ``(batch, seq_len, ...)`` with broadcast-compatible trailing dims.

    Returns:
        ``log(x_t)`` for every step, shaped ``(batch, seq_len, ...)``.
    """
    if logA.shape[1] == 0:
        shape = torch.broadcast_shapes(logA.shape, logB.shape, logC.shape, logD.shape)
        return logA.new_empty(shape)

    # Time-major layout so that indexing step t is a contiguous slice.
    logA = logA.transpose(0, 1).contiguous()
    logB = logB.transpose(0, 1).contiguous()
    logC = logC.transpose(0, 1).contiguous()
    logD = logD.transpose(0, 1).contiguous()

    # Running product starts at the identity: log(1) = 0 on the diagonal,
    # log(0) = -inf off the diagonal.
    cur_a = torch.zeros_like(logA[0])
    cur_b = torch.full_like(logB[0], -float("inf"))
    cur_c = torch.full_like(logC[0], -float("inf"))
    cur_d = torch.zeros_like(logD[0])

    preds = []
    for t in range(logA.shape[0]):
        a_t, b_t, c_t, d_t = logA[t], logB[t], logC[t], logD[t]

        # P <- M_t @ P, with products as log-sums and sums as logaddexp.
        new_a = torch.logaddexp(a_t + cur_a, b_t + cur_c)
        new_b = torch.logaddexp(a_t + cur_b, b_t + cur_d)
        new_c = torch.logaddexp(c_t + cur_a, d_t + cur_c)
        new_d = torch.logaddexp(c_t + cur_b, d_t + cur_d)

        # Projective renormalisation (exact for Möbius maps, see docstring).
        shift = torch.maximum(
            torch.maximum(new_a, new_b), torch.maximum(new_c, new_d)
        ).detach()
        cur_a = new_a - shift
        cur_b = new_b - shift
        cur_c = new_c - shift
        cur_d = new_d - shift

        # Evaluate the composed map at x = 1.
        preds.append(torch.logaddexp(cur_a, cur_b) - torch.logaddexp(cur_c, cur_d))

    return torch.stack(preds, dim=1)


def linear_scan(alpha, r):
    """Sequential reference for the recurrence ``h_t = alpha_t * h_{t-1} + r_t``.

    The scan runs over dim 1 (time) starting from ``h_{-1} = 0``. Every other
    dimension is element-wise, so any number of trailing state dims is
    accepted, as long as ``alpha`` and ``r`` broadcast against each other.

    Args:
        alpha: per-step decay factors, shape ``(batch, seq_len, *state_dims)``.
        r: per-step additive inputs, broadcast-compatible with ``alpha``.

    Returns:
        The stacked states ``h_0 ... h_{L-1}``, shape ``(batch, seq_len, *state_dims)``.
    """
    out_shape = torch.broadcast_shapes(alpha.shape, r.shape)
    seq_len = alpha.shape[1]
    # The carried state has every dimension except time.
    state_shape = tuple(out_shape[:1]) + tuple(out_shape[2:])

    if seq_len == 0:
        # Same dtype the non-empty path ends up with after type promotion.
        dtype = torch.promote_types(
            torch.promote_types(alpha.dtype, r.dtype), torch.get_default_dtype()
        )
        return torch.zeros(out_shape, device=alpha.device, dtype=dtype)

    # Time-major copies so alpha_tm[t] / r_tm[t] are contiguous slices.
    alpha_tm = alpha.transpose(0, 1).contiguous()
    r_tm = r.transpose(0, 1).contiguous()

    h = torch.zeros(state_shape, device=alpha.device)
    states = []
    for t in range(seq_len):
        h = alpha_tm[t] * h + r_tm[t]
        states.append(h)
    return torch.stack(states, dim=1)


class KalmanFilter(nn.Module):
    """
    Core diagonal Kalman filter with a log-space Möbius scan.

    Each of the ``d_model * d_state`` state elements is filtered
    independently.

    The posterior precision follows the Möbius recursion

        λ_t = ((1 + q_t φ_t) λ_{t-1} + a_t² φ_t) / (q_t λ_{t-1} + a_t²),

    where φ_t = Σ_r h² / σ²_obs. It is evaluated in log space by
    ``log_mobius_scan``.

    The information vector η_t = λ_t μ_t follows η_t = α_t η_{t-1} + r_t and
    is evaluated by ``linear_scan``, with the corrected decay

        α_t = a_t / (a_t² + q_t · λ_{t-1}^+)
    """

    def __init__(
        self,
        dim,
        d_state,
        discretize=True,
        a_parameterization="log_space",
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        process_noise_scale=0.01,
        selective_h=True,
        mimo_rank=1,
        use_b_bias=False,
        clip_value=None,
        compile_ssm=False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.d_model = dim
        self.d_state = d_state
        self.discretize = discretize
        self.a_parameterization = a_parameterization
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.gauss_visualization = kwargs.get("gauss_visualization", False)
        self.compile_ssm = compile_ssm

        # 1-element buffer used only to report the module's device.
        self.register_buffer("_device_indicator", torch.zeros(1))

        # A parameterization.
        if a_parameterization == "log_space":
            # One decay rate per d_model*d_state element; A = -exp(lambda_log).
            lambda_init = torch.ones(self.d_model * self.d_state)
            self.lambda_log = nn.Parameter(torch.log(lambda_init))
            self.lambda_log._no_weight_decay = True
        else:
            # One value per d_model channel, broadcast over d_state in _compute_A.
            self.A = nn.Parameter(torch.ones(self.d_model))
            self.A._no_weight_decay = True

        # h: observation matrix. When selective, it is supplied per token.
        if selective_h:
            self.h = None
        else:
            self.h = nn.Parameter(torch.ones(self.d_state, self.mimo_rank))

        # Continuous-time process noise at full d_model*d_state resolution.
        self.process_noise = nn.Parameter(
            torch.ones(self.d_model * self.d_state) * process_noise_scale
        )

        # Optional per-channel bias added to the observation vectors.
        self.use_b_bias = use_b_bias
        if use_b_bias:
            self.b_bias = nn.Parameter(
                torch.zeros(self.d_model, self.d_state, self.mimo_rank)
            )

        # Step size Δ (stored pre-softplus): log-uniform in [dt_min, dt_max].
        if discretize:
            dt = torch.rand(self.d_model * self.d_state) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)
            inv_dt = self._inverse_softplus(torch.exp(dt).clamp(min=dt_init_floor))
            self.delta = nn.Parameter(inv_dt)

        if not TORCH_2_8_AVAILABLE:
            warnings.warn(
                "Warning: PyTorch 2.8 not available, using fallback prefix_scan",
                UserWarning,
            )

        # Keep the scan out of torch.compile unless explicitly requested.
        if not compile_ssm:
            self.ssm = torch.compiler.disable(self.ssm)

    @property
    def device(self):
        return self._device_indicator.device

    def _init_mean_cov(self, batch_size: int):
        """Return a zero mean and unit variance, both of shape (batch_size, d_state)."""
        initial_mean = torch.zeros(batch_size, self.d_state, device=self.device)
        initial_cov = torch.ones(batch_size, self.d_state, device=self.device)
        return initial_mean, initial_cov

    def _compute_A(self):
        """Return the continuous-time transition diagonal, shape (d_model*d_state,)."""
        if self.a_parameterization == "log_space":
            return -torch.exp(self.lambda_log)
        # The direct parameterization stores one value per d_model channel.
        # Broadcast it over d_state so it lines up with delta and the process
        # noise, which are both (d_model*d_state,).
        return self._expand_outer_product(self.A, self.d_state)

    def _discretize_ou(self, A, delta):
        """Discrete transition a = exp(Δ·A)."""
        if A.ndim < delta.ndim:
            A = A[None, None, :]
        return torch.exp(delta * A)

    @staticmethod
    def _inverse_softplus(dt: torch.Tensor) -> torch.Tensor:
        """Inverse of softplus: x + log(1 - exp(-x))."""
        return dt + torch.log(-torch.expm1(-dt))

    def _discretize_process_noise_ou(self, A, q_continuous, delta):
        """OU discrete process noise: q / (2|A|) · (1 - exp(-2|A|Δ))."""
        if A.ndim < delta.ndim:
            A = A[None, None, :]
        if q_continuous.ndim < delta.ndim:
            q_continuous = q_continuous[None, None, :]
        lambda_positive = A.abs().clamp(min=1e-8)
        incr = -torch.expm1(-2.0 * lambda_positive * delta)
        return (q_continuous / (2.0 * lambda_positive)) * incr

    def _expand_outer_product(self, input, exp_factor):
        """Repeat each entry of the last dim ``exp_factor`` times (channel-major).

        Accepts ``(d,)`` or ``(B, L, d)`` and returns ``(d*exp_factor,)`` or
        ``(B, L, d*exp_factor)`` respectively.
        """
        if input.ndim == 1:
            d_model = input.shape[0]
            reshaped = input[:, None]
            expanded = reshaped.expand(-1, exp_factor).reshape(d_model * exp_factor)
        else:
            B, L, d_model = input.shape
            reshaped = input[:, :, :, None]
            expanded = reshaped.expand(-1, -1, -1, exp_factor).reshape(
                B, L, d_model * exp_factor
            )
        return expanded

    def _compute_delta(self, batch_size, seq_len, delta_bias=None):
        """Return positive step sizes Δ of shape (batch_size, seq_len, d_model*d_state).

        ``delta_bias`` is added to the learned pre-softplus ``self.delta``. Its
        trailing width may be ``d_model`` (one bias per channel, shared over
        d_state), ``1`` (one bias per token, e.g. KLABlock's
        ``Linear(d_inner, 1)`` delta projection), or ``d_model*d_state``.
        """
        delta = self.delta[None, None, :]
        if delta_bias is not None:
            if delta_bias.shape[-1] == self.d_model:
                delta_bias = self._expand_outer_product(delta_bias, self.d_state)
            # Width 1 and d_model*d_state broadcast directly against delta.
            delta = delta + delta_bias
        else:
            delta = delta.expand(batch_size, seq_len, -1)
        return F.softplus(delta) + 1e-7

    def ssm(
        self,
        mu: torch.Tensor,
        var: torch.Tensor,
        h: torch.Tensor,
        a: torch.Tensor,
        q: torch.Tensor,
        prior: bool = False,
        return_extras: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Pure SSM filtering with a log-space Möbius scan and corrected eta update.

        Args:
            mu: observation means, (B, L, d_model, r).
            var: observation variances, (B, L, d_model, r).
            h: observation vectors, (B, L, d_state, r).
            a: discrete transition, shaped (d_model*d_state,),
                (B, L, d_model*d_state) or (B, L, d_model, d_state).
            q: discrete process noise; same accepted shapes as ``a``.
            prior: if True, return the one-step-ahead prior instead of the posterior.
            return_extras: currently unused; ``extras`` is always empty.

        Returns:
            ``(mean, variance, extras)``, where mean and variance are
            (B, L, d_model, d_state).

        Note on precision: the scans (log_mobius_scan, linear_scan) require
        fp32 for numerical stability due to log/exp operations. All other
        operations (matmuls, projections) can remain in the autocast dtype
        (bf16/fp16).
        """
        assert mu.ndim == 4
        assert var.ndim == 4
        assert h.ndim == 4

        B, L, d_model, _ = mu.shape
        d_state = h.shape[2]
        device = mu.device
        input_dtype = mu.dtype  # restored on the outputs after the fp32 scans
        EPS = 1e-12

        # Bring a and q to (B, L, d_model, d_state).
        if a.ndim == 1:
            a = a.view(1, 1, d_model, d_state).expand(B, L, -1, -1)
        elif a.ndim == 3:
            a = a.reshape(B, L, d_model, d_state)
        if q.ndim == 1:
            q = q.view(1, 1, d_model, d_state).expand(B, L, -1, -1)
        elif q.ndim == 3:
            q = q.reshape(B, L, d_model, d_state)

        a2 = a.pow(2)

        # phi = Σ_r h² / σ²: the precision contributed by each observation.
        inv_var = 1.0 / var
        if self.use_b_bias:
            # Effective h per channel: (B, L, d_model, d_state, r) -> (..., r, d_state).
            h_eff = h.unsqueeze(2) + self.b_bias.unsqueeze(0).unsqueeze(0)
            h_eff_t = h_eff.permute(0, 1, 2, 4, 3).contiguous()
            del h_eff
            phi = torch.matmul(inv_var.unsqueeze(-2), h_eff_t.square()).squeeze(-2)
        else:
            h_eff_t = None
            phi = torch.matmul(inv_var, h.pow(2).transpose(-2, -1))
        phi = phi.clamp(min=EPS)

        # ---- Posterior precision: log-space Möbius scan (fp32) ----
        logA = ((1 / a2) * (1 + q * phi)).clamp(min=EPS).log().float()
        logB = phi.log().float()
        logC = ((1 / a2) * q).clamp(min=EPS).log().float()
        logD = torch.zeros(B, L, d_model, d_state, device=device, dtype=torch.float32)

        log_post_prec = log_mobius_scan(logA, logB, logC, logD)
        del logA, logB, logC, logD
        post_prec = log_post_prec.exp()
        post_variance = (-log_post_prec).exp()
        del log_post_prec

        # ---- Information vector: η_t = α_t η_{t-1} + r_t (fp32) ----
        # λ_{t-1}^+ for every step: [λ_0, λ_1^+, ..., λ_{L-1}^+]. λ_0 = 1 matches
        # the x = 1 starting point that log_mobius_scan evaluates from.
        lam0 = torch.ones(B, 1, d_model, d_state, device=device, dtype=torch.float32)
        prev_post_prec = torch.cat([lam0, post_prec[:, :-1, :, :]], dim=1)
        del lam0

        # α_t = a_t / (a_t² + q_t · λ_{t-1}^+). Not clamped from below: a
        # positive floor would zero the decay whenever a_t < 0 (e.g.
        # discretize=False with log-space A), contradicting the μ⁻ = a μ⁺
        # prediction used for the prior.
        den = (a2.float() + q.float() * prev_post_prec).clamp(min=EPS)
        alpha = a.float() / den
        del den, prev_post_prec

        # r_t = Σ_r h · μ_obs / σ²_obs, the information added by the observation.
        mu_over_var = mu / var
        if h_eff_t is not None:
            r = torch.matmul(mu_over_var.unsqueeze(-2), h_eff_t).squeeze(-2)
        else:
            r = torch.matmul(mu_over_var, h.transpose(-2, -1))

        eta = linear_scan(alpha, r.float())
        del alpha, r

        post_mean = (eta / post_prec).to(input_dtype)
        post_variance = post_variance.to(input_dtype)

        extras = dict()

        if prior:
            # Kalman predict step: μ⁻ = a μ⁺ and P⁻ = a² P⁺ + q.
            next_prior_mean = a * post_mean
            next_prior_variance = a2 * post_variance + q
            return next_prior_mean, next_prior_variance, extras
        return post_mean, post_variance, extras

    def forward(
        self,
        latent_token: torch.Tensor,
        token_variance: torch.Tensor,
        h_projected: Optional[torch.Tensor] = None,
        delta_bias: Optional[torch.Tensor] = None,
    ):
        """Filter a sequence.

        Args:
            latent_token: observations, (B, L, d_model*mimo_rank).
            token_variance: observation variances, same shape as ``latent_token``.
            h_projected: per-token observation vectors, (B, L, d_state*mimo_rank).
                Required when the filter was built with ``selective_h=True``.
            delta_bias: optional additive bias for Δ (see ``_compute_delta``).

        Returns:
            ``(post_mean, post_variance, extras)`` from ``ssm``.
        """
        B, L, _ = latent_token.shape

        A = self._compute_A()

        # Process noise, kept strictly positive (and optionally clipped).
        q_base = self.process_noise.clamp(min=1e-7)
        if self.clip_value is not None:
            q_base = q_base.clamp(max=self.clip_value)

        latent_token = latent_token.reshape(B, L, self.d_model, self.mimo_rank)
        token_variance = token_variance.reshape(B, L, self.d_model, self.mimo_rank)

        if h_projected is not None:
            h = h_projected.reshape(B, L, self.d_state, self.mimo_rank)
        elif self.h is not None:
            h = self.h.reshape(1, 1, self.d_state, self.mimo_rank).expand(B, L, -1, -1)
        else:
            raise ValueError(
                "h_projected is required when the filter was built with selective_h=True"
            )

        if self.discretize:
            # OU discretization of both the transition and the process noise.
            delta = self._compute_delta(B, L, delta_bias)
            a = self._discretize_ou(A, delta)
            q = self._discretize_process_noise_ou(A, q_base, delta)
        else:
            a = A
            q = q_base

        post_mean, post_variance, extras = self.ssm(
            mu=latent_token,
            var=token_variance,
            h=h,
            a=a,
            q=q,
            return_extras=self.gauss_visualization,
        )

        return post_mean, post_variance, extras


class KLABlock(nn.Module):
    """
    Kalman Linear Attention (KLA) Block.

    A sequence modeling layer combining:
    - Input projection with optional gating
    - Causal convolution
    - Kalman filter (KalmanFilter) for temporal modeling
    - Skip connection (learned lambda, or Mamba-style D)
    - Optional output projection (Mamba parameterization)

    ``forward`` maps (B, L, dim) to (B, L, dim). With ``return_variance=True``
    it also returns a same-shaped diagonal variance estimate.
    """

    def __init__(
        self,
        dim,
        d_state=16,
        discretize=True,
        a_parameterization="log_space",
        selective_h=True,
        selective_process=False,
        selective_delta=False,
        h_feature_map_mode='repeat',
        gating_mode="multiplicative",
        mlp_hidden_units=None,
        use_causal_conv=True,
        conv_kernel_size=4,
        conv_activation="silu",
        conv_groups=None,
        use_gating=True,
        gating_activation="silu",
        use_fused_norm_gate=False,
        use_lambda_skip=True,
        lambda_skip_mode="scalar",
        lambda_init=-1.0,
        process_noise_scale=0.01,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        bias=False,
        return_variance=False,
        mimo_rank=1,
        use_qk_norm=True,
        use_c_bias=False,
        use_b_bias=False,
        compile_ssm=False,
        clip_value=None,
        scale_variance_through_gating=True,
        use_mamba_parameterization=False,
        expand=1,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.d_model = dim
        self.d_state = d_state if d_state is not None else dim
        self.return_variance = return_variance
        self.gating_mode = gating_mode
        self.selective_h = selective_h
        self.scale_variance_through_gating = scale_variance_through_gating
        self.selective_process = selective_process
        self.selective_delta = selective_delta
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.use_fused_norm_gate = use_fused_norm_gate
        self.use_mamba_parameterization = use_mamba_parameterization

        # 1-element buffer used only to report the module's device.
        self.register_buffer("_device_indicator", torch.zeros(1))
        self._last_variance = None
        self._last_extras = None
        self.gauss_visualization = False

        # Mamba-style expansion (if enabled).
        if use_mamba_parameterization:
            self.expand = expand
            self.d_inner = int(self.expand * self.d_model)
        else:
            self.d_inner = self.d_model

        # Input projection (doubled width when a gate stream is needed).
        if use_gating:
            self.in_proj = nn.Linear(dim, self.d_inner * 2, bias=bias)
        else:
            self.in_proj = nn.Linear(dim, self.d_inner, bias=bias)

        self.mlp = None

        # Causal conv1d.
        self.use_causal_conv = use_causal_conv
        if use_causal_conv:
            # Mamba uses depthwise convolution (groups=d_inner).
            if use_mamba_parameterization:
                if not (conv_groups is None or conv_groups == self.d_inner):
                    warnings.warn(f"Mamba parameterization requires conv_groups=None or {self.d_inner}, got {conv_groups}", UserWarning)
                conv_groups = self.d_inner
            elif conv_groups is None:
                conv_groups = self.d_inner  # Default to depthwise

            self.conv1d = nn.Conv1d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                kernel_size=conv_kernel_size,
                groups=conv_groups,
                padding=conv_kernel_size - 1,
                bias=True,
            )
            self.conv_activation = conv_activation

        # Sensor projection: observation mean + log-variance (mimo_rank*d_inner
        # each), observation vectors h (mimo_rank*d_state), readout q (d_state).
        sensor_output_dim = (
            2 * self.mimo_rank * self.d_inner
            + self.mimo_rank * self.d_state
            + self.d_state
        )
        self.sensor_proj = nn.Linear(self.d_inner, sensor_output_dim, bias=bias)

        self.use_qk_norm = use_qk_norm
        self.use_c_bias = use_c_bias
        if use_c_bias:
            self.c_bias = nn.Parameter(torch.ones(self.d_state))

        # Selective delta: one scalar bias per token.
        if selective_delta and discretize:
            self.delta_proj = nn.Linear(self.d_inner, 1, bias=False)
        else:
            self.delta_proj = None

        # Gating.
        self.use_gating = use_gating
        if use_gating:
            if gating_activation == "silu":
                self.gate_act = nn.SiLU()
            elif gating_activation == "gelu":
                self.gate_act = nn.GELU()
            elif gating_activation == "relu":
                self.gate_act = nn.ReLU()
            else:
                self.gate_act = nn.SiLU()

        # Skip connection (D for Mamba mode, lambda for KLA mode).
        self.lambda_skip_mode = lambda_skip_mode
        if use_mamba_parameterization:
            if use_lambda_skip:
                warnings.warn(
                    "use_lambda_skip=True is ignored when use_mamba_parameterization=True. "
                    "Using Mamba's D skip parameter instead.",
                    UserWarning
                )
            # Mamba uses a D skip parameter (d_inner dimension, init to 1s).
            self.D = nn.Parameter(torch.ones(self.d_inner))
            self.D._no_weight_decay = True
            self.lambda_param = None
        elif use_lambda_skip:
            # KLA lambda skip (scalar or per-channel vector).
            if lambda_skip_mode == "vector":
                self.lambda_param = nn.Parameter(
                    torch.full((self.d_inner,), lambda_init)
                )
            else:
                self.lambda_param = nn.Parameter(torch.tensor(lambda_init))
            self.D = None
        else:
            self.lambda_param = None
            self.D = None

        # Output projection (Mamba parameterization: d_inner -> d_model).
        if use_mamba_parameterization:
            self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
        else:
            self.out_proj = None

        # Core Kalman filter.
        self.ssm = KalmanFilter(
            dim=self.d_inner,
            d_state=self.d_state,
            discretize=discretize,
            a_parameterization=a_parameterization,
            dt_min=dt_min,
            dt_max=dt_max,
            dt_init_floor=dt_init_floor,
            process_noise_scale=process_noise_scale,
            selective_h=selective_h,
            mimo_rank=self.mimo_rank,
            use_b_bias=use_b_bias,
            clip_value=self.clip_value,
            compile_ssm=compile_ssm,
            gauss_visualization=self.gauss_visualization,
        )

        if compile_ssm:
            try:
                self.ssm = torch.compile(self.ssm)
            except Exception as e:
                # Compilation is best-effort; keep the eager module on failure.
                warnings.warn(f"Warning: torch.compile failed for KalmanFilter: {e}", UserWarning)

    @property
    def device(self):
        return self._device_indicator.device

    def _causal_conv(self, z, seq_len):
        """Apply the causal conv and its activation to ``z`` of shape (B, L, d_inner)."""
        z_conv = self.conv1d(z.transpose(-1, -2))
        # conv1d pads both ends by kernel_size-1; keeping the first seq_len
        # outputs drops the ones that would see future tokens.
        z_conv = z_conv[..., :seq_len]
        activation = {"relu": F.relu, "gelu": F.gelu}.get(self.conv_activation, F.silu)
        return activation(z_conv).transpose(-1, -2)

    def _project_variance_out(self, y_var):
        """Propagate a diagonal variance through ``out_proj``.

        For y = W x + b, treating the input coordinates as uncorrelated,
        Var(y_i) = Σ_j W_ij² Var(x_j). The bias adds no variance, and squaring
        W keeps the result non-negative.
        """
        return F.linear(y_var, self.out_proj.weight.pow(2))

    def forward(self, x, *args, **kwargs):
        """Run the block on ``x`` of shape (B, L, dim).

        Returns ``y`` of shape (B, L, dim), or ``(y, y_var)`` when
        ``return_variance`` is set. Extra positional/keyword arguments are ignored.
        """
        B, L, _ = x.shape

        # Input projection (value stream + optional gate stream).
        projected = self.in_proj(x)
        if self.use_gating:
            z, gate = projected.chunk(2, dim=-1)
        else:
            z, gate = projected, None

        if self.use_causal_conv:
            z = self._causal_conv(z, L)

        residual = z

        # Sensor projection.
        sensor_projected = self.sensor_proj(z)
        latent_token, log_var, h_projected, q_projected = torch.split(
            sensor_projected,
            [
                self.mimo_rank * self.d_inner,
                self.mimo_rank * self.d_inner,
                self.mimo_rank * self.d_state,
                self.d_state,
            ],
            dim=-1,
        )

        # QK norm.
        if self.use_qk_norm:
            h_proj_reshape = h_projected.view(B, L, self.mimo_rank, self.d_state)
            h_proj_reshape = l2_norm_fn(h_proj_reshape)
            h_projected = h_proj_reshape.view(B, L, self.mimo_rank * self.d_state).to(
                latent_token
            )
            q_projected = l2_norm_fn(q_projected).to(latent_token)

        if self.clip_value is not None:
            h_projected = h_projected.clamp(-self.clip_value, self.clip_value)
            q_projected = q_projected.clamp(-self.clip_value, self.clip_value)

        if self.use_c_bias:
            q_projected = q_projected + self.c_bias

        token_variance = F.softplus(log_var) + 1e-6
        if self.clip_value is not None:
            token_variance = token_variance.clamp(max=self.clip_value)

        # Selective delta: delta_proj emits one value per token. Broadcast it
        # over the d_inner channels so the filter can expand it per state.
        delta_bias = None
        if self.delta_proj is not None:
            delta_bias = self.delta_proj(z).expand(-1, -1, self.d_inner)

        # With selective_h=False the filter uses its own learned h parameter.
        h_input = h_projected if self.selective_h else None

        post_mean, post_variance, _ = self.ssm(
            latent_token, token_variance, h_input, delta_bias
        )

        # Contract the state dimension with the readout vector q.
        post_mean = post_mean.view(B, L, self.d_inner, self.d_state)
        y = torch.matmul(post_mean, q_projected.unsqueeze(-1)).squeeze(-1)
        y_var = None
        if self.return_variance:
            post_variance = post_variance.view(B, L, self.d_inner, self.d_state)
            y_var = torch.matmul(
                post_variance, q_projected.pow(2).unsqueeze(-1)
            ).squeeze(-1)

        # Skip connections (applied to the post-conv input of the SSM).
        if self.D is not None:
            y = y + self.D * residual
        elif self.lambda_param is not None:
            y = y + self.lambda_param * residual

        # Gating.
        if self.use_gating and gate is not None:
            g = self.gate_act(gate)
            if self.gating_mode == "additive":
                y = y + g
            else:
                y = y * g
                if self.scale_variance_through_gating and y_var is not None:
                    y_var = y_var * g.pow(2)

        # Output projection (Mamba-style: d_inner back to d_model).
        if self.out_proj is not None:
            y = self.out_proj(y)
            if y_var is not None:
                y_var = self._project_variance_out(y_var)

        if self.return_variance:
            return y, y_var
        return y


# Legacy names kept so code importing the pre-rename classes keeps working.
GaussSSMFullCorrected = GaussSSMFull = KalmanFilter
GaussBlockFullCorrected = GaussBlockFull = KLABlock
