import jax
import jax.numpy as jnp
import sde.jax.markov_approximation as ma


# Number of latent dimensions; sizes x0 and the leading axis of y0 in the script below.
num_latents: int = 1

class Model:
    """Example model handed to ``ma.solve_vector``.

    Holds the ``gamma`` / ``omega`` values produced by the
    ``markov_approximation`` helpers and exposes three callables
    (``b``, ``u``, ``s``). Every callable takes a ``params`` argument
    that this model ignores.

    NOTE(review): presumably ``b`` / ``s`` play the drift / diffusion
    roles for the solver -- confirm against ``ma.solve_vector``.
    """

    def __init__(self):
        # NOTE(review): the meaning of 5, 1e2, 0.7 and 1.0 is defined by
        # ma.gamma_by_gamma_max / ma.omega_optimized_2 -- confirm there.
        gamma = ma.gamma_by_gamma_max(5, 1e2)
        self.gamma = gamma
        self.omega = ma.omega_optimized_2(gamma, 0.7, 1.0)

    def b(self, params, t, x):
        """Return ``-x``; ``params`` and ``t`` are unused."""
        negated = -x
        return negated

    def u(self, params, t, x, y):
        """Return ``sin(x) * cos(t)``; ``params`` and ``y`` are unused."""
        time_factor = jnp.cos(t)
        return time_factor * jnp.sin(x)

    def s(self, params, t, x):
        """Return ``sin(10 * t)`` broadcast against ``ones_like(x)``.

        For a scalar ``t`` the result has the shape of ``x``.
        """
        amplitude = jnp.sin(10 * t)
        return jnp.ones_like(x) * amplitude

# Build the model and zero-valued initial state, then run the solver once.
model = Model()
t0 = 0.0
x0 = jnp.zeros(num_latents)
# One row per latent dimension, one column per entry of model.gamma.
y0 = jnp.zeros((num_latents, len(model.gamma)))
ts, xs, log_path = ma.solve_vector(
    None,  # NOTE(review): presumably the `params` forwarded to Model's callables
    model,
    model.omega,
    x0,
    y0,
    t0,
    10000,
    1e-5,
    jax.random.PRNGKey(7),  # fixed seed so the run is reproducible
)

# Spot-check one intermediate sample, then report the third solver output.
print(ts[113], xs[113])
print(log_path)
