import jax
import jax.numpy as jnp
from jax import random

from src.baselines.PRSSM import PRSSM, PRSSMConfig


def generate_linear_ssm(key, T=40, state_dim=2, obs_dim=2, process_std=0.1, obs_std=0.05, A=None):
    """Simulate a linear-Gaussian state-space model.

    The latent state evolves as ``x_t = A @ x_{t-1} + w_t`` starting from a
    zero state (which is not included in the output). It is observed through
    ``y_t = C @ x_t + v_t`` with ``C = eye(obs_dim, state_dim)``. The noise
    terms ``w_t`` and ``v_t`` are i.i.d. standard normal draws scaled by
    ``process_std`` and ``obs_std`` respectively.

    Args:
        key: JAX PRNG key used to draw both noise sequences.
        T: Number of time steps to simulate.
        state_dim: Dimension of the latent state.
        obs_dim: Dimension of the observations.
        process_std: Scale of the process noise ``w_t``.
        obs_std: Scale of the observation noise ``v_t``.
        A: Optional ``(state_dim, state_dim)`` transition matrix. When omitted,
            ``[[0.95, 0.1], [0.0, 0.98]]`` is used for ``state_dim == 2``.
            Other sizes use an upper-bidiagonal matrix with 0.95 on the
            diagonal and 0.1 on the superdiagonal. Because it is triangular,
            its eigenvalues are the diagonal entries, all below 1.

    Returns:
        Tuple ``(xs, ys)`` of shapes ``(T, state_dim)`` and ``(T, obs_dim)``.

    Raises:
        ValueError: If ``A`` does not have shape ``(state_dim, state_dim)``.
    """
    if A is None:
        if state_dim == 2:
            # Original fixed dynamics; kept verbatim so existing seeds reproduce.
            A = jnp.array([[0.95, 0.1], [0.0, 0.98]])
        else:
            A = 0.95 * jnp.eye(state_dim) + 0.1 * jnp.eye(state_dim, k=1)
    else:
        A = jnp.asarray(A)
    if A.shape != (state_dim, state_dim):
        raise ValueError(
            f"A must have shape {(state_dim, state_dim)}, got {A.shape}"
        )

    C = jnp.eye(obs_dim, state_dim)
    x0 = jnp.zeros((state_dim,))
    key1, key2 = random.split(key)
    w = random.normal(key1, (T, state_dim)) * process_std
    v = random.normal(key2, (T, obs_dim)) * obs_std

    def step(x_prev, noise):
        # `noise` is the (w_t, v_t) slice that scan peels off along axis 0.
        w_t, v_t = noise
        x = A @ x_prev + w_t
        y = C @ x + v_t
        return x, (x, y)

    _, (xs, ys) = jax.lax.scan(step, x0, (w, v))
    return xs, ys


def test_prssm_runs_and_has_reasonable_shapes():
    """Smoke-test PRSSM on a small simulated linear state-space dataset.

    Checks that the simulated data have the expected shapes. Also checks that
    a short fit yields a finite ELBO.
    """
    T, d, p = 40, 2, 2
    # Independent subkeys for simulation and fitting. Reusing one PRNG key for
    # both may correlate the data noise with the fit's randomness.
    data_key, fit_key = random.split(random.PRNGKey(0))
    xs, ys = generate_linear_ssm(data_key, T=T, state_dim=d, obs_dim=p)
    assert xs.shape == (T, d)
    assert ys.shape == (T, p)

    cfg = PRSSMConfig(
        state_dim=d,
        obs_dim=p,
        n_inducing=16,
        rec_hidden=32,
        mc_samples=4,
        learning_rate=3e-3,
    )
    model = PRSSM(cfg)
    state = model.fit(fit_key, ys, num_steps=50)
    # jnp.all keeps the check valid whether `elbo` is a scalar or an array
    # (NOTE(review): the layout of the fit state is defined in PRSSM; confirm).
    assert bool(jnp.all(jnp.isfinite(state.elbo)))


