import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

from diffuse.conditional import CondSDE, SDEState, pmcmc


class IdentityMask:
    """Mask provider whose mask is one at every position."""

    def make(self, x):
        """Return an all-ones array with the same shape and dtype as ``x``."""
        all_ones = jnp.ones_like(x)
        return all_ones


class GaussianSDE(CondSDE):
    """``CondSDE`` configured with the score of a Gaussian ``N(mu, sigma**2)``.

    The parent is initialised with a constant ``beta`` of ``2 * sigma**2``,
    an :class:`IdentityMask`, a final time ``tf`` of ``1.0`` and the score
    ``(mu - x) / sigma**2`` (the gradient of the Gaussian log-density).
    ``mu`` and ``sigma`` are also stored on the instance for use in ``path``.
    """

    def __init__(self, mu, sigma):
        def constant_beta(t):
            # Independent of ``t``: the same value is returned at every time.
            return 2 * sigma**2

        def gaussian_score(x, t):
            # d/dx log N(x; mu, sigma**2); ``t`` is accepted but unused.
            return (mu - x) / sigma**2

        super().__init__(
            beta=constant_beta,
            mask=IdentityMask(),
            tf=1.0,
            score=gaussian_score,
        )
        self.mu = mu
        self.sigma = sigma

    def path(self, key, state, t):
        """Draw a new state from ``state`` using a single noisy step of size ``t``.

        Args:
            key: JAX PRNG key used to sample the Gaussian noise.
            state: Two-element state; only its first element (the position)
                is read, the second is ignored.
            t: Step size; also stored as the second field of the result.

        Returns:
            ``SDEState(new_x, t)`` where
            ``new_x = x + (mu - x) * t + sigma * noise * sqrt(t)``, which has
            the form of one Euler-Maruyama step.
        """
        position, _ = state
        pull_to_mean = self.mu - position
        noise = jax.random.normal(key, position.shape)
        scaled_noise = self.sigma * noise
        # Same left-to-right summation order as the one-line formula.
        moved = position + pull_to_mean * t + scaled_noise * jnp.sqrt(t)
        return SDEState(moved, t)


@pytest.mark.parametrize("mu,sigma", [(0.0, 1.0), (2.0, 1.5)])
def test_pmcmc_gaussian(mu, sigma):
    """Check that ``pmcmc`` particles match the mean and std of ``GaussianSDE``.

    Starts from 1000 standard-normal particles of shape ``(1,)``, draws a
    standard-normal observation ``y`` (also used as ``xi``), runs ``pmcmc``
    once, and asserts that the sample mean and standard deviation of the
    returned particles are within ``0.1`` of ``mu`` and ``sigma``.
    """
    gaussian_sde = GaussianSDE(mu, sigma)

    # Derive one independent key per random draw. A JAX PRNG key must not be
    # consumed by a sampler and then split again, so all three are created
    # up front from the seed.
    init_key, obs_key, pmcmc_key = random.split(random.PRNGKey(0), 3)

    n_particles = 1000
    x_shape = (1,)
    x_p = random.normal(init_key, (n_particles,) + x_shape)
    log_Z_p = 0.0

    y = random.normal(obs_key, x_shape)
    xi = y  # For Gaussian, observation is the same as the latent state

    # Run PMCMC; the returned log-normaliser is not checked by this test.
    particles, _ = pmcmc(x_p, log_Z_p, pmcmc_key, y, xi, gaussian_sde)

    # Summary statistics over every particle.
    mean_estimate = jnp.mean(particles)
    std_estimate = jnp.std(particles)

    print(f"True mu: {mu}, Estimated mean: {mean_estimate}")
    print(f"True sigma: {sigma}, Estimated std: {std_estimate}")

    # Assert that the estimates are close to the true values
    np.testing.assert_allclose(mean_estimate, mu, atol=0.1)
    np.testing.assert_allclose(std_estimate, sigma, atol=0.1)
