import jax
import jax.numpy as jnp
from jax import random

from src.baselines.SVGPSSM import SVGPSSMConfig, SVGPSSM, SVGPSSMStructured


def generate_linear_ssm(key, T=40, state_dim=2, obs_dim=2, process_std=0.05, obs_std=0.05, A=None):
    """Simulate a linear-Gaussian state-space model.

    The latent state evolves as ``x_t = A @ x_{t-1} + w_t`` starting from a
    zero state, and is observed as ``y_t = H @ x_t + v_t`` with
    ``H = jnp.eye(obs_dim, state_dim)``. The noise terms ``w_t`` and ``v_t`` are
    i.i.d. zero-mean Gaussians scaled by ``process_std`` and ``obs_std``.

    Args:
        key: JAX PRNG key; split internally into process- and observation-noise keys.
        T: number of time steps to simulate.
        state_dim: dimension of the latent state.
        obs_dim: dimension of each observation.
        process_std: standard deviation of the process noise ``w_t``.
        obs_std: standard deviation of the observation noise ``v_t``.
        A: optional ``(state_dim, state_dim)`` transition matrix. When omitted,
            the fixed matrix ``[[0.95, 0.1], [0.0, 0.98]]`` is used, which
            requires ``state_dim == 2``.

    Returns:
        Tuple ``(xs, ys)`` of shapes ``(T, state_dim)`` and ``(T, obs_dim)``.

    Raises:
        ValueError: if ``A`` does not have shape ``(state_dim, state_dim)``.
    """
    if A is None:
        A = jnp.array([[0.95, 0.1], [0.0, 0.98]])
    else:
        A = jnp.asarray(A)
    # Fail early with a clear message instead of an opaque matmul shape error inside scan.
    if A.shape != (state_dim, state_dim):
        raise ValueError(
            f"transition matrix A has shape {A.shape}, expected ({state_dim}, {state_dim}); "
            "pass A explicitly when state_dim != 2"
        )
    H = jnp.eye(obs_dim, state_dim)
    x0 = jnp.zeros((state_dim,))
    key_w, key_v = random.split(key)
    w = random.normal(key_w, (T, state_dim)) * process_std
    v = random.normal(key_v, (T, obs_dim)) * obs_std

    def step(x_prev, noise):
        w_t, v_t = noise
        x = A @ x_prev + w_t
        y = H @ x + v_t
        return x, (x, y)

    # Scan directly over the pre-sampled noise sequences rather than indexing by t.
    _, (xs, ys) = jax.lax.scan(step, x0, (w, v))
    return xs, ys


def test_structured_variational_runs():
    """Smoke-test SVGPSSMStructured.fit_once on synthetic linear-SSM data.

    Checks that the fitted result exposes ``means`` of shape ``(T, 2)`` and
    ``covs`` of shape ``(T, 2, 2)``, that the means are finite, and that every
    per-step covariance has strictly positive eigenvalues.
    """
    T = 40
    # Independent keys for data generation and model initialisation; reusing a
    # single JAX key for both would correlate their random draws.
    data_key, init_key = random.split(random.PRNGKey(1))
    _, ys = generate_linear_ssm(data_key, T=T)
    cfg = SVGPSSMConfig(state_dim=2, obs_dim=2, n_inducing=12, process_std=0.05)
    base = SVGPSSM(cfg)
    # Only the parameters are needed from init; the other two outputs are unused.
    # NOTE(review): the third positional argument (2) is presumably a dimension
    # (state_dim and obs_dim are both 2 here) — confirm against SVGPSSM.init.
    params, _, _ = base.init(init_key, T, 2)
    model = SVGPSSMStructured(cfg)
    st = model.fit_once(params, ys)
    assert st.means.shape == (T, 2)
    assert st.covs.shape == (T, 2, 2)
    assert jnp.all(jnp.isfinite(st.means))
    # eigvalsh batches over the leading time axis, so the covariance at every
    # time step is checked for positive definiteness, not only the first one.
    eigs = jnp.linalg.eigvalsh(st.covs)
    assert jnp.all(eigs > 0)


