"""
Linear-Gaussian SSM simulation and RTS smoother utilities
线性-高斯状态空间模型仿真与RTS平滑工具

Provides:
- simulate_lgssm: 1-D simulation generating (x_k, y_k)
- kalman_rts: 1-D Kalman filter + RTS smoother (closed form); returns smoothed means/covariances
- simulate_lgssm_nd / kalman_rts_nd: general N-D versions (supports 1-D/2-D/...)

All functions are JAX-compatible.
所有函数均为 JAX 兼容实现。
"""

from __future__ import annotations

from typing import Tuple

import jax
import jax.numpy as jnp
from jax import lax


def _chol_psd(M: jax.Array) -> jax.Array:
    """
    Cholesky factor of a symmetric PSD matrix, regularized with diagonal jitter.

    The input is symmetrized and then factorized as ``M + jitter * I`` for
    ``jitter`` = 1e-9, 1e-8, ..., 1e-4. The factor from the first attempt whose
    entries are all finite is returned; if every attempt fails, the result of
    the last attempt (largest jitter) is returned as-is.

    ``jnp.linalg.cholesky`` reports failure by returning NaNs instead of
    raising, so a ``try/except`` retry loop can never trigger. Success is
    therefore detected with ``jnp.isfinite`` and the choice is made with
    ``jnp.where``, which also keeps the function traceable under ``jax.jit``.

    Args:
        M: Square matrix of shape (n, n).

    Returns:
        Cholesky factor of shape (n, n), as returned by
        ``jnp.linalg.cholesky`` for the regularized matrix.
    """
    M = 0.5 * (M + M.T)
    eye = jnp.eye(M.shape[0], dtype=M.dtype)
    jitter = 1e-9
    chol = jnp.linalg.cholesky(M + jitter * eye)
    for _ in range(5):
        jitter *= 10.0
        retry = jnp.linalg.cholesky(M + jitter * eye)
        # Keep the earliest successful factor; otherwise fall through to the
        # attempt with the next larger jitter.
        chol = jnp.where(jnp.all(jnp.isfinite(chol)), chol, retry)
    return chol


def simulate_lgssm(
    key: jax.Array,
    A: float,
    C: float,
    Q: float,
    R: float,
    K: int,
    x0: float = 0.0,
) -> Tuple[jax.Array, jax.Array]:
    """
    Simulate a 1-D linear Gaussian state-space model.

    Model:
        x_k = A x_{k-1} + w_k,   w_k ~ N(0, Q)
        y_k = C x_k + v_k,       v_k ~ N(0, R)

    Args:
        key: JAX PRNG key; split once into process- and observation-noise keys.
        A: State transition coefficient.
        C: Observation coefficient.
        Q: Process noise variance (its square root scales the noise).
        R: Observation noise variance (its square root scales the noise).
        K: Number of time steps to generate. ``K == 0`` yields empty outputs.
        x0: State before the first step; the first returned state is
            ``A * x0 + w_0``.

    Returns:
        ``(xs, ys)``: states and observations, each of shape (K,).
    """
    key_w, key_v = jax.random.split(key)
    w = jax.random.normal(key_w, (K,)) * jnp.sqrt(Q)
    v = jax.random.normal(key_v, (K,)) * jnp.sqrt(R)

    def step(x_prev, noise):
        w_t, v_t = noise
        x = A * x_prev + w_t
        y = C * x + v_t
        return x, (x, y)

    # Scan over the noise sequences themselves rather than over indices into
    # them; this also makes the zero-length case well defined.
    _, (xs, ys) = lax.scan(step, jnp.array(x0), (w, v))
    return xs, ys


def kalman_rts(
    y: jax.Array,
    A: float,
    C: float,
    Q: float,
    R: float,
    m0: float = 0.0,
    P0: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """
    Kalman filter followed by an RTS smoother for a 1-D linear Gaussian SSM.

    Model:
        x_k = A x_{k-1} + w_k,   w_k ~ N(0, Q)
        y_k = C x_k + v_k,       v_k ~ N(0, R)

    Args:
        y: Observations of shape (K,).
        A: State transition coefficient.
        C: Observation coefficient.
        Q: Process noise variance.
        R: Observation noise variance.
        m0: Mean of the state before the first step (the first filter step
            predicts from it and then updates with ``y[0]``).
        P0: Variance of the state before the first step.

    Returns:
        ``(m_s, P_s)``: smoothed means and variances, each of shape (K,).
        For empty ``y`` both outputs are empty.
    """
    y = jnp.asarray(y)

    def filter_step(carry, y_t):
        m_prev, P_prev = carry
        # Predict
        m_pred = A * m_prev
        P_pred = A * P_prev * A + Q
        # Update
        S = C * P_pred * C + R
        K_gain = P_pred * C / S
        m = m_pred + K_gain * (y_t - C * m_pred)
        P = (1.0 - K_gain * C) * P_pred
        return (m, P), (m, P, m_pred, P_pred)

    _, (m_f, P_f, m_pred_all, P_pred_all) = lax.scan(
        filter_step, (jnp.array(m0), jnp.array(P0)), y
    )

    # Nothing to smooth; also avoids indexing the last element of an empty array.
    if y.shape[0] == 0:
        return m_f, P_f

    def smooth_step(carry, inputs):
        m_next_s, P_next_s = carry
        m_f_t, P_f_t, m_pred_next, P_pred_next = inputs
        # Smoother gain: uses the one-step prediction of step t+1 made from
        # the filtered estimate at step t.
        G = P_f_t * A / P_pred_next
        m_s = m_f_t + G * (m_next_s - m_pred_next)
        P_s = P_f_t + G * (P_next_s - P_pred_next) * G
        return (m_s, P_s), (m_s, P_s)

    # Backward pass over t = K-2 .. 0. Filtered moments at t are paired with
    # the predicted moments at t+1 (hence the [:-1] / [1:] slices). With
    # reverse=True the stacked outputs come back in forward time order.
    _, (m_s_head, P_s_head) = lax.scan(
        smooth_step,
        (m_f[-1], P_f[-1]),
        (m_f[:-1], P_f[:-1], m_pred_all[1:], P_pred_all[1:]),
        reverse=True,
    )
    # At the final step the smoothed estimate equals the filtered one.
    m_s = jnp.concatenate([m_s_head, m_f[-1:]])
    P_s = jnp.concatenate([P_s_head, P_f[-1:]])
    return m_s, P_s


# ==============================
# General N-D versions / 通用N维
# ==============================

def simulate_lgssm_nd(
    key: jax.Array,
    A: jax.Array,
    C: jax.Array,
    Q: jax.Array,
    R: jax.Array,
    K: int,
    x0: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """
    Simulate an N-D linear Gaussian state-space model.

    Model:
        x_k = A x_{k-1} + w_k,   w_k ~ N(0, Q)
        y_k = C x_k + v_k,       v_k ~ N(0, R)

    Args:
        key: JAX PRNG key; split once into process- and observation-noise keys.
        A: State transition matrix, shape (d, d).
        C: Observation matrix, shape (p, d).
        Q: Process noise covariance, shape (d, d).
        R: Observation noise covariance, shape (p, p).
        K: Number of time steps to generate. ``K == 0`` yields empty outputs.
        x0: State before the first step, shape (d,); the first returned state
            is ``A @ x0 + w_0``.

    Returns:
        ``(xs, ys)``: states of shape (K, d) and observations of shape (K, p).
    """
    d = x0.shape[0]
    p = C.shape[0]
    key_w, key_v = jax.random.split(key)
    # Colour standard-normal draws with the Cholesky factors: each row z
    # becomes z @ L.T, whose covariance is L @ L.T.
    w = jax.random.normal(key_w, (K, d)) @ _chol_psd(Q).T
    v = jax.random.normal(key_v, (K, p)) @ _chol_psd(R).T

    def step(x_prev, noise):
        w_t, v_t = noise
        x = A @ x_prev + w_t
        y = C @ x + v_t
        return x, (x, y)

    # Scan over the noise sequences themselves rather than over indices into
    # them; this also makes the zero-length case well defined.
    _, (xs, ys) = lax.scan(step, x0, (w, v))
    return xs, ys


def kalman_rts_nd(
    y: jax.Array,
    A: jax.Array,
    C: jax.Array,
    Q: jax.Array,
    R: jax.Array,
    m0: jax.Array,
    P0: jax.Array,
) -> Tuple[jax.Array, jax.Array]:
    """
    Kalman filter followed by an RTS smoother for an N-D linear Gaussian SSM.

    Model:
        x_k = A x_{k-1} + w_k,   w_k ~ N(0, Q)
        y_k = C x_k + v_k,       v_k ~ N(0, R)

    Args:
        y: Observations of shape (K, p).
        A: State transition matrix, shape (d, d).
        C: Observation matrix, shape (p, d).
        Q: Process noise covariance, shape (d, d).
        R: Observation noise covariance, shape (p, p).
        m0: Mean of the state before the first step, shape (d,) (the first
            filter step predicts from it and then updates with ``y[0]``).
        P0: Covariance of the state before the first step, shape (d, d).

    Returns:
        ``(m_s, P_s)``: smoothed means of shape (K, d) and smoothed
        covariances of shape (K, d, d). For empty ``y`` both are empty.
    """
    y = jnp.asarray(y)
    d = m0.shape[0]
    eye = jnp.eye(d)

    def filter_step(carry, y_t):
        m_prev, P_prev = carry
        # Predict
        m_pred = A @ m_prev
        P_pred = A @ P_prev @ A.T + Q
        # Update
        S = C @ P_pred @ C.T + R
        K_gain = P_pred @ C.T @ jnp.linalg.inv(S)
        m = m_pred + K_gain @ (y_t - C @ m_pred)
        # Joseph form of the covariance update
        I_KC = eye - K_gain @ C
        P = I_KC @ P_pred @ I_KC.T + K_gain @ R @ K_gain.T
        return (m, P), (m, P, m_pred, P_pred)

    _, (m_f, P_f, m_pred_all, P_pred_all) = lax.scan(filter_step, (m0, P0), y)

    # Nothing to smooth; also avoids indexing the last element of an empty array.
    if y.shape[0] == 0:
        return m_f, P_f

    def smooth_step(carry, inputs):
        m_next_s, P_next_s = carry
        m_f_t, P_f_t, m_pred_next, P_pred_next = inputs
        # Smoother gain: uses the one-step prediction of step t+1 made from
        # the filtered estimate at step t.
        G = P_f_t @ A.T @ jnp.linalg.inv(P_pred_next)
        m_s = m_f_t + G @ (m_next_s - m_pred_next)
        P_s = P_f_t + G @ (P_next_s - P_pred_next) @ G.T
        return (m_s, P_s), (m_s, P_s)

    # Backward pass over t = K-2 .. 0. Filtered moments at t are paired with
    # the predicted moments at t+1 (hence the [:-1] / [1:] slices). With
    # reverse=True the stacked outputs come back in forward time order.
    _, (m_s_head, P_s_head) = lax.scan(
        smooth_step,
        (m_f[-1], P_f[-1]),
        (m_f[:-1], P_f[:-1], m_pred_all[1:], P_pred_all[1:]),
        reverse=True,
    )
    # At the final step the smoothed estimate equals the filtered one.
    m_s = jnp.concatenate([m_s_head, m_f[-1:]])
    P_s = jnp.concatenate([P_s_head, P_f[-1:]])
    return m_s, P_s


def matrix_infty_norm(M: jax.Array) -> jax.Array:
    """
    Induced infinity norm of a matrix: the maximum absolute row sum.

    Absolute values are summed along the last axis and the maximum over all
    resulting entries is returned as a scalar array.
    """
    abs_row_sums = jnp.abs(M).sum(axis=-1)
    return abs_row_sums.max()



