"""
Kalman Linear Attention (KLA) Block.

A sequence modeling layer that implements Kalman filtering with an associative scan,
enabling parallel computation over sequences while maintaining principled uncertainty
quantification.

Key features:
- Full process noise parameterization at d_model*d_state dimension
- Corrected information vector (eta) update using proper Kalman gain
- Log-space associative scan for numerical stability
- OU discretization for continuous-to-discrete conversion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
import math

from kernels import log_mobius_scan

try:
    from torch import _higher_order_ops as hops
    TORCH_2_8_AVAILABLE = True
except ImportError:
    TORCH_2_8_AVAILABLE = False
    raise ImportError(
        "torch._higher_order_ops (PyTorch 2.8+) is required for KLABlock. "
        "Please upgrade PyTorch to version 2.8 or higher."
    )

try:
    from fla.modules.l2norm import l2_norm as l2_norm_fn
except ImportError:
    from fla.modules.l2norm import l2_norm_fn


class KalmanFilter(nn.Module):
    """
    Core Kalman filter module with log-space associative scan.

    Runs an independent scalar Kalman filter on each of the
    ``d_model * d_state`` latent channels. The posterior precision follows a
    Möbius (linear-fractional) recursion that is evaluated with a parallel
    scan in log space; the information vector follows an affine recursion
    that uses the corrected gain:
        α_t = a_t / (a_t² + q_t · λ_{t-1}^+)

    Args:
        dim: Number of observation channels (stored as ``d_model``).
        d_state: Number of latent states per observation channel.
        discretize: If True, a learnable step size is created and the
            OU discretization is applied to the transition and the noise.
        a_parameterization: ``'log_space'`` stores ``lambda_log`` with
            ``d_model * d_state`` entries and uses ``A = -exp(lambda_log)``.
            Any other value stores ``A`` directly with ``d_model`` entries;
            ``forward`` repeats each entry ``d_state`` times.
        dt_min, dt_max, dt_init_floor: The step size is initialised
            log-uniformly in ``[dt_min, dt_max]`` and floored at
            ``dt_init_floor``.
        process_noise_scale: Initial value of every process-noise entry.
        selective_h: If True no static observation vector is created and
            ``h_projected`` must be passed to ``forward``.
        mimo_rank: Size of the trailing rank axis of the observations.
        use_b_bias: If True, a learnable ``(d_model, d_state, mimo_rank)``
            bias is added to the observation vector.
        clip_value: Optional upper bound applied to the process noise.
        **kwargs: Only ``gauss_visualization`` is read.
    """
    def __init__(
        self,
        dim,
        d_state,
        discretize=True,
        a_parameterization='log_space',
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        process_noise_scale=0.01,
        selective_h=True,
        mimo_rank=1,
        use_b_bias=False,
        clip_value=None,
        *args, **kwargs
    ):
        super().__init__()
        self.d_model = dim
        self.d_state = d_state
        self.discretize = discretize
        self.a_parameterization = a_parameterization
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.gauss_visualization = kwargs.get('gauss_visualization', False)

        # One-element buffer used only to report the module's device.
        self.register_buffer('_device_indicator', torch.zeros(1))

        full_dim = self.d_model * self.d_state

        # Continuous-time transition rate.
        if a_parameterization == 'log_space':
            lambda_init = torch.ones(full_dim)
            self.lambda_log = nn.Parameter(torch.log(lambda_init))
            self.lambda_log._no_weight_decay = True
        else:
            # Stored per observation channel; expanded to full_dim in forward.
            self.A = nn.Parameter(torch.ones(self.d_model))
            self.A._no_weight_decay = True

        # h: static observation vector, only when it is not input-dependent.
        if selective_h:
            self.h = None
        else:
            self.h = nn.Parameter(torch.ones(self.d_state, self.mimo_rank))

        # Process noise, parameterized directly at d_model * d_state.
        self.process_noise = nn.Parameter(torch.ones(full_dim) * process_noise_scale)

        # Optional bias on the observation vector.
        self.use_b_bias = use_b_bias
        if use_b_bias:
            self.b_bias = nn.Parameter(torch.zeros(self.d_model, self.d_state, self.mimo_rank))

        # Step size, stored pre-softplus so that softplus(delta) starts
        # log-uniformly distributed in [dt_min, dt_max].
        if discretize:
            log_dt = torch.rand(full_dim) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
            inv_dt = self._inverse_softplus(torch.exp(log_dt).clamp(min=dt_init_floor))
            self.delta = nn.Parameter(inv_dt)

        # NOTE: the former "PyTorch 2.8 not available" warning was removed.
        # The module-level import raises ImportError in that situation, so
        # the branch could never execute.

    @property
    def device(self):
        """Device on which this module's buffers currently live."""
        return self._device_indicator.device

    def _init_mean_cov(self, batch_size: int):
        """Return the initial state: zero mean and unit variance, each (batch_size, d_state)."""
        initial_mean = torch.zeros(batch_size, self.d_state, device=self.device)
        initial_cov = torch.ones(batch_size, self.d_state, device=self.device)
        return initial_mean, initial_cov

    def _compute_A(self):
        """Return the continuous-time rate exactly as parameterized (no expansion)."""
        if self.a_parameterization == 'log_space':
            lambda_pos = torch.exp(self.lambda_log)
            return -lambda_pos
        else:
            return self.A

    def _compute_A_full(self):
        """Return the rate with one entry per latent channel (d_model * d_state).

        The non-log-space parameterization stores only ``d_model`` entries;
        each is repeated ``d_state`` times so it lines up with ``delta`` and
        ``process_noise``.
        """
        A = self._compute_A()
        full_dim = self.d_model * self.d_state
        if A.ndim == 1 and A.shape[0] == self.d_model and A.shape[0] != full_dim:
            A = self._expand_outer_product(A, self.d_state)
        return A

    def _discretize_ou(self, A, delta):
        """Discrete transition ``exp(delta * A)``; ``A`` is broadcast over leading axes."""
        if A.ndim < delta.ndim:
            A = A[None, None, :]
        return torch.exp(delta * A)

    @staticmethod
    def _inverse_softplus(dt: torch.Tensor) -> torch.Tensor:
        """Return ``x`` such that ``softplus(x) == dt`` (for ``dt > 0``)."""
        return dt + torch.log(-torch.expm1(-dt))

    def _discretize_process_noise_ou(self, A, q_continuous, delta):
        """Discrete process noise ``q / (2|A|) * (1 - exp(-2|A| * delta))``."""
        if A.ndim < delta.ndim:
            A = A[None, None, :]
        if q_continuous.ndim < delta.ndim:
            q_continuous = q_continuous[None, None, :]
        # |A| is floored so the division below stays finite.
        lambda_positive = A.abs().clamp(min=1e-8)
        incr = -torch.expm1(-2.0 * lambda_positive * delta)
        return (q_continuous / (2.0 * lambda_positive)) * incr

    def _expand_outer_product(self, input, exp_factor):
        """Repeat every entry of the last axis ``exp_factor`` times.

        Accepts a 1-D tensor ``(d,)`` or a 3-D tensor ``(B, L, d)`` and
        returns ``(d * exp_factor,)`` or ``(B, L, d * exp_factor)``.
        """
        if input.ndim == 1:
            d_model = input.shape[0]
            reshaped = input[:, None]
            expanded = reshaped.expand(-1, exp_factor).reshape(d_model * exp_factor)
        else:
            B, L, d_model = input.shape
            reshaped = input[:, :, :, None]
            expanded = reshaped.expand(-1, -1, -1, exp_factor).reshape(B, L, d_model * exp_factor)
        return expanded

    def _positive_delta(self, delta_bias, batch, length):
        """Return strictly positive step sizes of shape (batch, length, d_model * d_state).

        ``delta_bias`` may be ``None`` or have a last axis of width

        * ``1`` - one shared offset, broadcast over every latent channel;
        * ``d_model`` - one offset per observation channel, repeated
          ``d_state`` times;
        * ``d_model * d_state`` - one offset per latent channel.

        Raises:
            ValueError: If the last axis of ``delta_bias`` has another width.
        """
        full_dim = self.d_model * self.d_state
        raw = self.delta[None, None, :]  # (1, 1, full_dim)
        if delta_bias is not None:
            width = delta_bias.shape[-1]
            if width == full_dim or width == 1:
                # Already aligned, or broadcastable as a shared offset.
                offset = delta_bias
            elif width == self.d_model:
                offset = self._expand_outer_product(delta_bias, self.d_state)
            else:
                raise ValueError(
                    f"delta_bias last dimension must be 1, {self.d_model} or "
                    f"{full_dim}, got {width}"
                )
            raw = raw + offset
        raw = raw.expand(batch, length, -1)
        return F.softplus(raw) + 1e-7

    def ssm(
        self,
        mu: torch.Tensor,
        var: torch.Tensor,
        h: torch.Tensor,
        a: torch.Tensor,
        q: torch.Tensor,
        prior: bool = False,
        return_extras: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Pure SSM filtering with log-space associative scan and corrected eta update.

        Args:
            mu: Observation means, (B, L, d_model, r).
            var: Observation variances, (B, L, d_model, r).
            h: Observation vectors, (B, L, d_state, r).
            a: Discrete transition, (dim,) or (B, L, dim), dim = d_model * d_state.
            q: Discrete process noise, (dim,) or (B, L, dim).
            prior: If True, return ``a * post_mean`` and ``a² * post_variance``
                instead of the posterior moments.
            return_extras: Accepted for interface compatibility; it is not
                read, and the diagnostics dict is always returned.

        Returns:
            ``(mean, variance, extras)`` where mean and variance are
            (B, L, dim) and ``extras`` maps names to detached diagnostics.
        """
        assert mu.ndim == 4
        assert var.ndim == 4
        assert h.ndim == 4

        B, L, d_model, _ = mu.shape
        _, _, d_state, _ = h.shape
        device = mu.device
        EPS = 1e-12

        dim = d_model * d_state
        if a.ndim == 1:
            a = a[None, None, :].expand(B, L, -1)  # (B, L, dim)
        if q.ndim == 1:
            q = q[None, None, :].expand(B, L, -1)  # (B, L, dim)

        a2 = a.pow(2)  # (B, L, dim)
        inv_var = 1.0 / var  # (B, L, d_model, r)

        # phi: precision contributed by the observation at each step.
        h_eff = None
        if self.use_b_bias:
            # (B, L, 1, d_state, r) + (1, 1, d_model, d_state, r)
            h_eff = h.unsqueeze(2) + self.b_bias.unsqueeze(0).unsqueeze(0)  # (B, L, d_model, d_state, r)
            h_eff_sq = h_eff.square().permute(0, 1, 2, 4, 3).contiguous()  # (B, L, d_model, r, d_state)
            phi = torch.matmul(inv_var.unsqueeze(-2), h_eff_sq).squeeze(-2)  # (B, L, d_model, d_state)
        else:
            h_squared = h.pow(2)  # (B, L, d_state, r)
            phi = torch.matmul(inv_var, h_squared.transpose(-2, -1))  # (B, L, d_model, d_state)
        phi = phi.reshape(B, L, dim).clamp(min=EPS)  # (B, L, dim)

        # Precision recursion as a Möbius map  λ_t = (A·λ + B) / (C·λ + D)
        # with A = (1 + q·phi)/a², B = phi, C = q/a², D = 1, kept in log space.
        logA = ((1/a2) * (1 + q * phi)).clamp(min=EPS).log()  # (B, L, dim)
        logB = phi.clamp(min=EPS).log()  # (B, L, dim)
        logC = ((1/a2) * q).clamp(min=EPS).log()  # (B, L, dim)
        logD = torch.zeros(B, L, dim, device=device)  # log(1) == 0

        logM = torch.stack([logA, logB, logC, logD], dim=2)  # (B, L, 4, dim)

        def log_matrix_compose(logX, logY):
            # 2x2 matrix product Y @ X (X is the earlier step), via logaddexp.
            A1, B1, C1, D1 = logX.unbind(1)
            A2, B2, C2, D2 = logY.unbind(1)
            LA = torch.logaddexp(A2 + A1, B2 + C1)
            LB = torch.logaddexp(A2 + B1, B2 + D1)
            LC = torch.logaddexp(C2 + A1, D2 + C1)
            LD = torch.logaddexp(C2 + B1, D2 + D1)
            return torch.stack([LA, LB, LC, LD], dim=1)

        logM_prefix = hops.associative_scan(  # (B, L, 4, dim)
            combine_fn=log_matrix_compose,
            xs=logM,
            dim=1,
            combine_mode="generic"
        )

        logA_p, logB_p, logC_p, logD_p = logM_prefix.unbind(2)  # (B, L, dim)

        prior_mean0, prior_cov0 = self._init_mean_cov(B)  # (B, d_state), (B, d_state)
        if prior_mean0.shape[-1] != dim:
            expansion_ratio = dim // prior_mean0.shape[-1]
            prior_mean0 = prior_mean0.repeat_interleave(expansion_ratio, dim=-1)
            prior_cov0 = prior_cov0.repeat_interleave(expansion_ratio, dim=-1)

        # Apply the accumulated Möbius map to the initial precision λ₀.
        log_lam0 = -prior_cov0.log()  # (B, dim)
        log_num = torch.logaddexp(logA_p + log_lam0[:, None, :], logB_p)  # (B, L, dim)
        log_den = torch.logaddexp(logC_p + log_lam0[:, None, :], logD_p)  # (B, L, dim)
        log_post_prec = log_num - log_den

        # NOTE: the file also imports a `log_mobius_scan` kernel; a
        # commented-out call to it used to sit here as an alternative to
        # the generic scan above. It was not active and has been dropped.

        post_prec = log_post_prec.exp()
        post_variance = (-log_post_prec).exp()

        # ========== CORRECTED INFORMATION VECTOR UPDATE ==========

        # λ₀ = initial precision = 1/prior_cov0
        lam0 = log_lam0.exp()  # (B, dim)

        # λ_{t-1}^+ for each step t (shifted): [λ0, λ1^+, λ2^+, ..., λ_{L-1}^+]
        prev_post_prec = torch.cat([lam0[:, None, :], post_prec[:, :-1, :]], dim=1)  # (B, L, dim)

        # α_t = a_t / (a_t² + q_t · λ_{t-1}^+)
        den = (a2 + q * prev_post_prec).clamp(min=EPS)
        alpha = (a / den).clamp(min=EPS)

        def info_compose(X, Y):
            # affine composition: (α1, β1) then (α2, β2) => (α2·α1, α2·β1 + β2)
            gain1, bias1 = X.unbind(1)
            gain2, bias2 = Y.unbind(1)
            return torch.stack([gain2 * gain1, gain2 * bias1 + bias2], dim=1)

        # r_t = h · μ_obs / σ²_obs (information contribution from observation)
        mu_over_var = mu / var
        if self.use_b_bias:
            h_eff_t = h_eff.permute(0, 1, 2, 4, 3).contiguous()  # (B, L, d_model, r, d_state)
            obs_info = torch.matmul(mu_over_var.unsqueeze(-2), h_eff_t).squeeze(-2)
        else:
            obs_info = torch.matmul(mu_over_var, h.transpose(-2, -1))
        obs_info = obs_info.reshape(B, L, dim)

        # Stack (α, r) for affine scan
        M_eta = torch.stack([alpha, obs_info], dim=2)

        M_eta_prefix = hops.associative_scan(  # (B, L, 2, dim)
            combine_fn=info_compose,
            xs=M_eta,
            dim=1,
            combine_mode="generic"
        )

        alpha_prefix = M_eta_prefix[:, :, 0, :]  # cumulative α product (B, L, dim)
        beta_prefix = M_eta_prefix[:, :, 1, :]   # accumulated r terms (B, L, dim)

        # Initial information: η₀ = μ₀ / σ²₀ = prior_mean / prior_cov
        eta0 = (prior_mean0 / prior_cov0)  # (B, dim)

        # Final information: η_t = α_prefix · η₀ + β_prefix
        eta = alpha_prefix * eta0[:, None, :] + beta_prefix  # (B, L, dim)

        post_mean = eta / post_prec  # (B, L, dim)

        # ========== END CORRECTED UPDATE ==========

        next_prior_mean = a * post_mean
        # NOTE(review): no process noise q is added here - confirm that
        # a²·σ² (rather than a²·σ² + q) is the intended quantity.
        next_prior_variance = a2 * post_variance

        extras = {
            'prior_mean': next_prior_mean.detach(),
            'prior_variance': next_prior_variance.detach(),
            'observation_influence': phi.detach(),
            'precision': post_prec.detach(),
            'information_vector': eta.detach(),
            'a_effective': a.detach(),
            'q_effective': q.detach(),
            'alpha': alpha.detach(),
            'alpha_prefix': alpha_prefix.detach(),
        }

        if prior:
            return next_prior_mean, next_prior_variance, extras
        else:
            return post_mean, post_variance, extras

    def forward(
        self,
        latent_token: torch.Tensor,
        token_variance: torch.Tensor,
        h_projected: Optional[torch.Tensor] = None,
        delta_bias: Optional[torch.Tensor] = None
    ):
        """Filter a sequence of noisy observations.

        Args:
            latent_token: Observation means, reshaped to
                (B, L, d_model, mimo_rank).
            token_variance: Observation variances, same layout.
            h_projected: Input-dependent observation vectors, reshaped to
                (B, L, d_state, mimo_rank). Required when the module was
                built with ``selective_h=True``.
            delta_bias: Optional additive offset on the pre-softplus step
                size; last axis of width 1, d_model or d_model * d_state.
                Ignored when ``discretize`` is False.

        Returns:
            ``(post_mean, post_variance, extras)`` with both moments shaped
            (B, L, d_model * d_state).

        Raises:
            ValueError: If neither ``h_projected`` nor a static ``h`` is
                available, or if ``delta_bias`` has an unsupported width.
        """
        B, L, _ = latent_token.shape

        if h_projected is None and self.h is None:
            raise ValueError(
                "h_projected is required when KalmanFilter is built with selective_h=True"
            )

        # Step 1: continuous-time rate, one entry per latent channel.
        A = self._compute_A_full()

        # Step 2: process noise is already (d_model*d_state,).
        q_base = self.process_noise.clamp(min=1e-7)
        if self.clip_value is not None:
            q_base = q_base.clamp(max=self.clip_value)

        # Step 3: split the flat feature axis into (channel, rank).
        latent_token = latent_token.reshape(B, L, self.d_model, self.mimo_rank)
        token_variance = token_variance.reshape(B, L, self.d_model, self.mimo_rank)

        if h_projected is not None:
            h = h_projected.reshape(B, L, self.d_state, self.mimo_rank)
        else:
            h = self.h.reshape(1, 1, self.d_state, self.mimo_rank).expand(B, L, -1, -1)

        # Step 4: OU discretization of both A and Q, if enabled.
        if self.discretize:
            delta = self._positive_delta(delta_bias, B, L)
            a = self._discretize_ou(A, delta)
            q = self._discretize_process_noise_ou(A, q_base, delta)
        else:
            a = A
            q = q_base

        # Step 5: run the Kalman filter.
        post_mean, post_variance, extras = self.ssm(
            mu=latent_token,
            var=token_variance,
            h=h,
            a=a,
            q=q,
            return_extras=self.gauss_visualization
        )

        # Report a and q as produced here (not broadcast inside ssm).
        extras['a_effective'] = a.detach() if isinstance(a, torch.Tensor) else a
        extras['q_effective'] = q.detach() if isinstance(q, torch.Tensor) else q

        return post_mean, post_variance, extras


class KLABlock(nn.Module):
    """
    Kalman Linear Attention (KLA) Block.

    A sequence modeling layer combining:
    - Input projection with optional gating
    - Causal convolution
    - Kalman filter (KalmanFilter) for temporal modeling
    - Output projection with skip connection

    Args:
        dim: Feature width of the input and of the output (``d_model``).
        d_state: Latent states per channel; ``None`` falls back to ``dim``.
        discretize, a_parameterization, process_noise_scale, dt_min, dt_max,
            dt_init_floor, mimo_rank, use_b_bias, clip_value: Forwarded to
            ``KalmanFilter``. ``clip_value`` additionally clamps the
            projected h/q vectors and the observation variance here.
        selective_h: If True the filter uses the per-token projected
            observation vector; if False it uses its static ``h`` parameter.
        selective_delta: If True (and ``discretize``), a per-token scalar
            offset on the step size is projected from the input.
        gating_mode: ``'additive'`` adds the activated gate to the output;
            any other value multiplies by it.
        use_causal_conv, conv_kernel_size, conv_activation, conv_groups:
            Optional left-padded 1-D convolution; ``conv_groups=None`` uses
            one group per channel. Unknown activations fall back to SiLU.
        use_gating, gating_activation: Optional output gate; unknown
            activations fall back to SiLU.
        use_lambda_skip, lambda_skip_mode, lambda_init: Optional learnable
            skip weight, per channel when the mode is ``'vector'``.
        bias: Whether the input and sensor projections carry a bias.
        return_variance: If True, ``forward`` returns ``(y, y_var)``.
        use_qk_norm: L2-normalize the projected h and q vectors.
        use_c_bias: Add a learnable bias to the read-out vector q.
        compile_ssm: Wrap the filter in ``torch.compile``.
        scale_variance_through_gating: Multiply ``y_var`` by the squared
            gate in multiplicative gating.
        selective_process, use_fused_norm_gate: Stored but not otherwise
            used in this class.
        h_feature_map_mode, mlp_hidden_units: Accepted but not used in this
            class.
        **kwargs: Only ``gauss_visualization`` is read.
    """
    def __init__(
        self,
        dim,
        d_state=16,
        discretize=True,
        a_parameterization='log_space',
        selective_h=True,
        selective_process=False,
        selective_delta=False,
        h_feature_map_mode=None,
        gating_mode='multiplicative',
        mlp_hidden_units=None,
        use_causal_conv=False,
        conv_kernel_size=3,
        conv_activation='silu',
        conv_groups=None,
        use_gating=False,
        gating_activation='silu',
        use_fused_norm_gate=False,
        use_lambda_skip=False,
        lambda_skip_mode='scalar',
        lambda_init=-1.0,
        process_noise_scale=0.01,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        bias=False,
        return_variance=False,
        mimo_rank=1,
        use_qk_norm=False,
        use_c_bias=False,
        use_b_bias=False,
        compile_ssm=True,
        clip_value=None,
        scale_variance_through_gating=True,
        *args, **kwargs
    ):
        super().__init__()
        self.d_model = dim
        self.d_state = d_state if d_state is not None else dim
        self.return_variance = return_variance
        self.gating_mode = gating_mode
        self.selective_h = selective_h
        self.scale_variance_through_gating = scale_variance_through_gating
        self.selective_process = selective_process
        self.selective_delta = selective_delta
        self.mimo_rank = mimo_rank
        self.clip_value = clip_value
        self.use_fused_norm_gate = use_fused_norm_gate

        # One-element buffer used only to report the module's device.
        self.register_buffer('_device_indicator', torch.zeros(1))
        # Detached results of the most recent forward pass.
        self._last_variance = None
        self._last_extras = None
        self.gauss_visualization = kwargs.get('gauss_visualization', False)

        # Input projection; doubled width when a gate branch is needed.
        in_width = self.d_model * 2 if use_gating else self.d_model
        self.in_proj = nn.Linear(dim, in_width, bias=bias)

        self.mlp = None

        # Causal conv1d: pad by (k - 1); forward keeps the first L outputs.
        self.use_causal_conv = use_causal_conv
        if use_causal_conv:
            self.conv1d = nn.Conv1d(
                in_channels=self.d_model,
                out_channels=self.d_model,
                kernel_size=conv_kernel_size,
                groups=conv_groups if conv_groups is not None else self.d_model,
                padding=conv_kernel_size - 1,
                bias=True
            )
            self.conv_activation = conv_activation

        # Sensor projection: [mean | log-variance | h | q] along the last axis.
        sensor_output_dim = 2 * self.mimo_rank * self.d_model + self.mimo_rank * self.d_state + self.d_state
        self.sensor_proj = nn.Linear(self.d_model, sensor_output_dim, bias=bias)

        self.use_qk_norm = use_qk_norm
        self.use_c_bias = use_c_bias
        if use_c_bias:
            self.c_bias = nn.Parameter(torch.ones(self.d_state))

        # Selective delta: one scalar step-size offset per token.
        if selective_delta and discretize:
            self.delta_proj = nn.Linear(self.d_model, 1, bias=False)
        else:
            self.delta_proj = None

        # Gating activation (unknown names fall back to SiLU).
        self.use_gating = use_gating
        if use_gating:
            gate_classes = {'silu': nn.SiLU, 'gelu': nn.GELU, 'relu': nn.ReLU}
            self.gate_act = gate_classes.get(gating_activation, nn.SiLU)()

        # Lambda skip
        self.lambda_skip_mode = lambda_skip_mode
        if use_lambda_skip:
            if lambda_skip_mode == 'vector':
                self.lambda_param = nn.Parameter(torch.full((self.d_model,), lambda_init))
            else:
                self.lambda_param = nn.Parameter(torch.tensor(lambda_init))
        else:
            self.lambda_param = None

        # Core Kalman filter
        self.ssm = KalmanFilter(
            dim=self.d_model,
            d_state=self.d_state,
            discretize=discretize,
            a_parameterization=a_parameterization,
            dt_min=dt_min,
            dt_max=dt_max,
            dt_init_floor=dt_init_floor,
            process_noise_scale=process_noise_scale,
            selective_h=selective_h,
            mimo_rank=self.mimo_rank,
            use_b_bias=use_b_bias,
            clip_value=self.clip_value,
            gauss_visualization=self.gauss_visualization
        )

        if compile_ssm:
            try:
                self.ssm = torch.compile(self.ssm)
            except Exception as e:
                print(f"Warning: torch.compile failed for KalmanFilter: {e}")

    @property
    def device(self):
        """Device on which this module's buffers currently live."""
        return self._device_indicator.device

    def _as_channel_state_grid(self, value, batch, length):
        """Reshape a flat filter quantity to a detached (B, L, d_model, d_state) tensor.

        A 1-D input (time-invariant) is broadcast over batch and length.
        """
        if value.ndim == 1:
            grid = value.reshape(1, 1, self.d_model, self.d_state)
            return grid.expand(batch, length, -1, -1).contiguous().detach()
        return value.reshape(batch, length, self.d_model, self.d_state).detach()

    def forward(self, x, *args, **kwargs):
        """Apply the block to ``x`` of shape (B, L, dim).

        Returns ``y`` of shape (B, L, d_model), or ``(y, y_var)`` when the
        block was built with ``return_variance=True``. The output variance
        and the diagnostics dict of this call are kept on
        ``_last_variance`` and ``_last_extras``. Extra positional and
        keyword arguments are ignored.
        """
        B, L, _ = x.shape

        # Input projection
        projected = self.in_proj(x)
        if self.use_gating:
            z, gate = projected.chunk(2, dim=-1)
        else:
            z, gate = projected, None

        # Causal conv1d over the sequence axis
        if self.use_causal_conv:
            conv_out = self.conv1d(z.transpose(-1, -2))[..., :L]
            activations = {'silu': F.silu, 'relu': F.relu, 'gelu': F.gelu}
            activation = activations.get(self.conv_activation, F.silu)
            z = activation(conv_out).transpose(-1, -2)

        residual = z

        # Sensor projection
        sensor_projected = self.sensor_proj(z)
        latent_token, log_var, h_projected, q_projected = torch.split(
            sensor_projected,
            [self.mimo_rank * self.d_model,
             self.mimo_rank * self.d_model,
             self.mimo_rank * self.d_state,
             self.d_state],
            dim=-1
        )

        # QK norm
        if self.use_qk_norm:
            # NOTE(review): this groups h as (mimo_rank, d_state) while the
            # filter reshapes it as (d_state, mimo_rank); the two layouts
            # coincide only for mimo_rank == 1 - confirm for higher ranks.
            h_proj_reshape = h_projected.view(B, L, self.mimo_rank, self.d_state)
            h_proj_reshape = l2_norm_fn(h_proj_reshape)
            h_projected = h_proj_reshape.view(B, L, self.mimo_rank * self.d_state).to(latent_token)
            q_projected = l2_norm_fn(q_projected).to(latent_token)

        if self.clip_value is not None:
            h_projected = h_projected.clamp(-self.clip_value, self.clip_value)
            q_projected = q_projected.clamp(-self.clip_value, self.clip_value)

        if self.use_c_bias:
            q_projected = q_projected + self.c_bias

        # softplus keeps the observation variance strictly positive.
        token_variance = F.softplus(log_var) + 1e-6
        if self.clip_value is not None:
            token_variance = token_variance.clamp(max=self.clip_value)

        # Selective delta: the projection yields one scalar per token. It is
        # broadcast to one value per channel, which is the width the filter
        # repeats d_state times to reach d_model * d_state.
        delta_bias = None
        if self.delta_proj is not None:
            delta_bias = self.delta_proj(z).expand(-1, -1, self.d_model)

        # SSM filtering. With selective_h=False the filter must fall back to
        # its static h parameter, so no projected h is handed over.
        h_for_filter = h_projected if self.selective_h else None
        post_mean, post_variance, extras = self.ssm(latent_token, token_variance, h_for_filter, delta_bias)

        if self.selective_h:
            h_effective = h_projected.detach()
        else:
            h_effective = self.ssm.h.detach().reshape(1, 1, -1).expand(B, L, -1)

        extras['token_variance'] = token_variance.detach()
        extras['h_projected'] = h_projected.detach()
        extras['h_effective'] = h_effective
        extras['q_projected'] = q_projected.detach()
        extras['latent_token'] = latent_token.detach()

        # Contract the state axis with the read-out vector q.
        post_mean = post_mean.reshape(B, L, self.d_model, self.d_state)
        post_variance = post_variance.reshape(B, L, self.d_model, self.d_state)
        q_col = q_projected.unsqueeze(-1)           # (B, L, d_state, 1)
        q_sq_col = q_projected.pow(2).unsqueeze(-1)  # (B, L, d_state, 1)
        y = torch.matmul(post_mean, q_col).squeeze(-1)
        y_var = torch.matmul(post_variance, q_sq_col).squeeze(-1)

        extras['post_variance_expanded'] = post_variance.detach()
        if 'prior_variance' in extras:
            pv = extras['prior_variance'].reshape(B, L, self.d_model, self.d_state)
            extras['prior_variance'] = torch.matmul(pv, q_sq_col).squeeze(-1).detach()

        # Reshape a_effective and q_effective to (B, L, d_model, d_state)
        for key in ('a_effective', 'q_effective'):
            if key in extras:
                extras[key] = self._as_channel_state_grid(extras[key], B, L)

        # Lambda skip
        if self.lambda_param is not None:
            y = y + self.lambda_param * residual

        # Gating
        if self.use_gating and gate is not None:
            g = self.gate_act(gate)
            if self.gating_mode == 'additive':
                y = y + g
            else:
                y = y * g
                if self.scale_variance_through_gating:
                    y_var = y_var * g.pow(2)

        self._last_variance = y_var.detach()
        self._last_extras = extras

        if self.return_variance:
            return y, y_var
        else:
            return y


# Legacy class names, kept so that older imports continue to resolve.
GaussSSMFullCorrected = GaussSSMFull = KalmanFilter
GaussBlockFullCorrected = GaussBlockFull = KLABlock
