import jax
import jax.numpy as jnp
from jax import random

from src.baselines.SVGPSSM import SVGPSSM, SVGPSSMConfig


def generate_linear_ssm(
    key,
    T=50,
    state_dim=2,
    obs_dim=2,
    process_std=0.1,
    obs_std=0.05,
    transition=None,
):
    """Simulate a linear-Gaussian state-space model starting from x0 = 0.

    Dynamics: ``x_t = A @ x_{t-1} + process_std * eps_t`` and
    ``y_t = H @ x_t + obs_std * nu_t``, where ``eps_t`` and ``nu_t`` are
    standard-normal draws. ``H = eye(obs_dim, state_dim)`` reads off the
    leading state coordinates; rows are zero if obs_dim > state_dim.

    Args:
        key: JAX PRNG key.
        T: Number of time steps.
        state_dim: Latent state dimension.
        obs_dim: Observation dimension.
        process_std: Standard deviation of the process noise.
        obs_std: Standard deviation of the observation noise.
        transition: Optional ``(state_dim, state_dim)`` transition matrix.
            The default is upper-bidiagonal, with 0.9 on the diagonal
            (0.95 in the last diagonal entry) and 0.1 on the superdiagonal.
            For ``state_dim=2`` this is ``[[0.9, 0.1], [0.0, 0.95]]``.

    Returns:
        ``(xs, ys)`` of shapes ``(T, state_dim)`` and ``(T, obs_dim)``.
        ``xs[t]`` is the state after the (t+1)-th transition; x0 itself is
        not included.

    Raises:
        ValueError: if ``transition`` is not of shape ``(state_dim, state_dim)``.
    """
    if transition is None:
        # Works for any state_dim; for state_dim=2 it reproduces the
        # historical hard-coded matrix exactly.
        A = (0.9 * jnp.eye(state_dim) + 0.1 * jnp.eye(state_dim, k=1)).at[-1, -1].set(0.95)
    else:
        A = jnp.asarray(transition)
        if A.shape != (state_dim, state_dim):
            raise ValueError(
                f"transition must have shape {(state_dim, state_dim)}, got {A.shape}"
            )
    H = jnp.eye(obs_dim, state_dim)
    x0 = jnp.zeros((state_dim,))
    # Split into three subkeys (third unused) so the noise drawn for a given
    # key stays identical to earlier versions of this helper.
    key_p, key_o, _ = random.split(key, 3)
    noise_p = random.normal(key_p, (T, state_dim)) * process_std
    noise_o = random.normal(key_o, (T, obs_dim)) * obs_std

    def step(x_prev, noise_t):
        w_t, v_t = noise_t
        x = A @ x_prev + w_t
        y = H @ x + v_t
        return x, (x, y)

    _, (xs, ys) = jax.lax.scan(step, x0, (noise_p, noise_o))
    return xs, ys


def test_svgpssm_runs_and_improves_elbo():
    """Smoke-test SVGPSSM: fit on simulated linear-SSM data, then predict.

    Checks that the ELBO stored on the fitted state is finite, and that
    ``predict`` returns a finite ``(T, state_dim)`` array.
    """
    # NOTE(review): despite its name, this test does not verify that the ELBO
    # improves during training. That would need the initial ELBO or the
    # training trajectory, and the API for either is not visible here. TODO.
    T = 60
    state_dim = 2
    obs_dim = 2
    # JAX PRNG keys must not be reused: give data simulation and model
    # fitting independent subkeys.
    data_key, fit_key = random.split(random.PRNGKey(0))
    _, ys = generate_linear_ssm(data_key, T=T, state_dim=state_dim, obs_dim=obs_dim)
    cfg = SVGPSSMConfig(
        state_dim=state_dim,
        obs_dim=obs_dim,
        n_inducing=16,
        mc_samples=4,
        learning_rate=5e-3,
    )
    model = SVGPSSM(cfg)
    state = model.fit(fit_key, ys, num_steps=50)
    # Basic sanity: the ELBO must be finite, and predictions must be finite
    # with one state vector per time step.
    assert jnp.isfinite(state.elbo)
    pred = model.predict(state.params, state.variational)
    assert pred.shape == (T, state_dim)
    assert jnp.all(jnp.isfinite(pred))


