import jax
import jax.numpy as jnp
from jax import random

from src.baselines.smc import SMCSmoother, SMCConfig


def generate_linear_gaussian_ssm(key, T=60, state_dim=2, obs_dim=2, process_std=0.1, obs_std=0.05):
    """Simulate a linear-Gaussian state-space model with a fixed 2x2 dynamics matrix.

    The latent state follows ``x_t = A @ x_{t-1} + w_t`` starting from
    ``x_0 = 0`` (the returned trajectory begins at the first step after
    ``x_0``). Observations are ``y_t = H @ x_t + v_t`` with
    ``H = eye(obs_dim, state_dim)``. Noise components are drawn i.i.d. as
    ``process_std * N(0, 1)`` and ``obs_std * N(0, 1)``.

    Args:
        key: JAX PRNG key; split internally into process- and
            observation-noise keys.
        T: Number of time steps to simulate.
        state_dim: Latent state dimension. Must equal 2, because ``A`` is a
            fixed 2x2 matrix.
        obs_dim: Observation dimension.
        process_std: Standard deviation of each process-noise component.
        obs_std: Standard deviation of each observation-noise component.

    Returns:
        Tuple ``(A, H, xs, ys)``; ``xs`` has shape ``(T, state_dim)`` and
        ``ys`` has shape ``(T, obs_dim)``.

    Raises:
        ValueError: If ``state_dim`` does not match the size of ``A``.
    """
    A = jnp.array([[0.95, 0.1], [0.0, 0.98]])
    # Fail early with a clear message instead of an opaque matmul shape
    # error inside lax.scan.
    if state_dim != A.shape[0]:
        raise ValueError(
            f"state_dim must be {A.shape[0]} to match the fixed dynamics matrix A, got {state_dim}"
        )
    H = jnp.eye(obs_dim, state_dim)
    x0 = jnp.zeros((state_dim,))
    k1, k2 = random.split(key)
    w = random.normal(k1, (T, state_dim)) * process_std
    v = random.normal(k2, (T, obs_dim)) * obs_std

    def step(x_prev, noise):
        # Advance one step using the pre-drawn noise pair for this time index.
        w_t, v_t = noise
        x = A @ x_prev + w_t
        y = H @ x + v_t
        return x, (x, y)

    # Scan directly over the noise arrays rather than indexing them by t.
    _, (xs, ys) = jax.lax.scan(step, x0, (w, v))
    return A, H, xs, ys


def make_model_funcs(A, H, process_std, obs_std):
    """Build the transition sampler and log-densities for a linear-Gaussian SSM.

    Both noise processes are isotropic Gaussians:
    ``x_t ~ N(A @ x_{t-1}, process_std**2 * I)`` and
    ``y_t ~ N(H @ x_t, obs_std**2 * I)``.

    Returns:
        ``(transition_sampler, transition_logpdf, obs_logpdf)`` where
        ``transition_sampler(x_prev, key)`` draws a next state,
        ``transition_logpdf(x_curr, x_prev)`` and ``obs_logpdf(y_t, x_t)``
        return scalar log-densities.
    """
    process_var = process_std ** 2
    obs_var = obs_std ** 2

    def _isotropic_logpdf(residual, variance):
        # Log-density of N(0, variance * I) evaluated at `residual`.
        n = residual.size
        return -0.5 * (
            jnp.sum((residual ** 2) / variance) + n * jnp.log(2 * jnp.pi) + n * jnp.log(variance)
        )

    def transition_sampler(x_prev, key):
        mean = A @ x_prev
        return mean + random.normal(key, x_prev.shape) * process_std

    def transition_logpdf(x_curr, x_prev):
        return _isotropic_logpdf(x_curr - A @ x_prev, process_var)

    def obs_logpdf(y_t, x_t):
        return _isotropic_logpdf(y_t - H @ x_t, obs_var)

    return transition_sampler, transition_logpdf, obs_logpdf


def test_smc_pf_ffbsi_shapes_weights_and_ess():
    """Smoke-test SMCSmoother filtering and backward simulation on a 2-D linear-Gaussian SSM.

    Checks the filter output shapes and that the filter's log-weights
    yield finite, normalized weights. Checks that both the recomputed and
    the filter-reported ESS lie in (1, num_particles]. Checks that the
    backward-simulation smoother returns finite trajectories of the
    configured shape.
    """
    key = random.PRNGKey(0)
    T = 60
    D = 2
    Dy = 2
    process_std = 0.1
    obs_std = 0.05
    # Independent keys for data simulation, filtering and smoothing. The
    # simulator splits its key internally with random.split(key). Splitting
    # the same key again here would hand the filter and smoother exactly the
    # keys that generated the process and observation noise.
    key_data, key_f, key_b = random.split(key, 3)
    A, H, _, ys = generate_linear_gaussian_ssm(
        key_data, T=T, state_dim=D, obs_dim=Dy, process_std=process_std, obs_std=obs_std
    )
    ts, tlp, olp = make_model_funcs(A, H, process_std, obs_std)

    cfg = SMCConfig(num_particles=512, num_smoothing_samples=32, resample_threshold=0.5, jitter_std=0.0)
    smc = SMCSmoother(cfg, transition_sampler=ts, transition_logpdf=tlp, obs_logpdf=olp)

    filt = smc.run_filter(key_f, ys, x_dim=D)

    # Shapes
    assert filt.particles.shape == (T, cfg.num_particles, D)
    assert filt.log_weights.shape == (T, cfg.num_particles)
    assert filt.ancestors.shape == (T, cfg.num_particles)
    assert filt.ess.shape == (T,)

    # Weights: softmax sums to 1 by construction, so the substantive check
    # is that no time step has NaN or all -inf log-weights. Either would
    # turn W into NaN.
    W = jax.nn.softmax(filt.log_weights, axis=-1)
    assert jnp.all(jnp.isfinite(W))
    sums = jnp.sum(W, axis=-1)
    assert jnp.allclose(sums, jnp.ones((T,)), rtol=1e-5, atol=1e-6)

    # ESS bounds. An absolute +1e-5 would round away in float32 near 512,
    # so use a relative tolerance to absorb rounding of 1/sum(W^2) at N.
    ess_upper = cfg.num_particles * (1.0 + 1e-5)
    ess_calc = 1.0 / jnp.sum(W ** 2, axis=-1)  # (T,)
    assert jnp.all(ess_calc > 1.0)
    assert jnp.all(ess_calc <= ess_upper)
    # Internal ESS (based on pre-resample weights) also in bounds
    assert jnp.all(filt.ess > 1.0)
    assert jnp.all(filt.ess <= ess_upper)

    # Backward simulation smoother
    bwd = smc.run_smoother(key_b, filt)
    assert bwd.trajectories.shape == (cfg.num_smoothing_samples, T, D)
    assert jnp.all(jnp.isfinite(bwd.trajectories))


