'''
Discretisation step functions for different sampling methods.
Each function takes the current state and step inputs, and returns the next step.
We can later compute JVPs and VJPs of these functions to obtain the single-step Jacobians.
'''

import jax
import jax.numpy as jnp

# Euler-Maruyama sampler
def euler_maruyama_sampler(
        x,
        t,
        dB_t,
        dt,
        drift_fn,
        sigma_fn,
):
    """Take one Euler-Maruyama step.

    Computes ``x_next = x + drift_fn(x, t) * dt + apply_sigma_fn(dB_t)``,
    where ``apply_sigma_fn`` is the second element returned by
    ``sigma_fn(x, t)``.

    Args:
        x: Current state.
        t: Current time.
        dB_t: Noise increment for this step. It is passed to
            ``apply_sigma_fn`` and added without any ``dt`` scaling, so it
            should already be the Brownian increment over the step
            (presumably with variance ``dt`` -- confirm against callers).
        dt: Step size; it multiplies only the drift term.
        drift_fn: Callable ``(x, t) -> drift``.
        sigma_fn: Callable ``(x, t) -> (sigma, apply_sigma_fn)``. Only
            ``apply_sigma_fn`` is used here; it maps ``dB_t`` to the
            diffusion (noise) term.

    Returns:
        Tuple ``(x_next, drift)``: the next state and the drift evaluated
        at ``(x, t)``.
    """
    drift = drift_fn(x, t)
    # Only the apply function is needed for the step; the diffusion value
    # returned alongside it is not used here.
    _, apply_sigma_fn = sigma_fn(x, t)
    noise_term = apply_sigma_fn(dB_t)
    x_next = x + drift * dt + noise_term
    return x_next, drift