import jax
import jax.numpy as jnp
from .multistep import ab_step, get_ab_eps_coef
from .rk import get_rk_fn
from .sde import MultiStepSDE, get_rev_ts
from .vpsde import VPSDE
from .helper import jax2th, th2jax

from functions.denoising import solve_ditillation_objective
import torch
import torch.nn.functional as F

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

def teacher_fori_loop(lower, upper, body_fun, init_val):
    xs = []
    val = init_val
    xs.append(val[0].to('cpu'))
    for i in range(lower, upper):
        val = body_fun(i, val)
        xs.append(val[0].to('cpu'))
    return xs, val

def get_lambdas_fori_loop(lower, upper, body_fun, init_val):
    xs = []
    lambdas = torch.zeros(upper, dtype=torch.float64)
    val = init_val
    xs.append(val[0])
    for i in range(lower, upper):
        val, lambdas[i] = body_fun(i, val)
        xs.append(val[0])
    return xs, val, lambdas


def distilled_ab_step(x, ei_coef, new_eps, eps_pred, alpha_t, alpha_next):
    x_coef, eps_coef = ei_coef[0], ei_coef[1:]
    full_eps_pred = [ new_eps, *eps_pred]
    epsilon=0
    for cur_coef, cur_eps in zip(eps_coef, full_eps_pred):
        epsilon += cur_coef * cur_eps
    epsilon = epsilon/(-x_coef*(1-alpha_t).sqrt()+(1-alpha_next).sqrt()) 
    x0 = (x-epsilon*(1-alpha_t).sqrt())/alpha_t.sqrt()

    x_next = alpha_next.sqrt()*x0+(1-alpha_next).sqrt()*epsilon

    return x_next, full_eps_pred[:-1], epsilon

def cond_fn(x, t_discrete, y, classifier, classifier_scale):
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        vec_t = (torch.ones(x.shape[0])).float().to(x.device) * t_discrete
        # ! the checkpoint need vec_t shift 1 :(
        logits = classifier(x_in, vec_t-1)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale

def get_sampler(sde, eps_fn, num_step, ts_phase="t", ts_order=2.0, method="rho_rk",ab_order=3, rk_method="3kutta", classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    if method.lower() == "rho_rk":
        return get_sampler_rho_rk(sde, eps_fn, ts_phase, ts_order, num_step, rk_method)
    elif method.lower() == "rho_ab":
        return get_sampler_rho_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order)
    elif method.lower() == "t_ab":
        return get_sampler_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, classifier, is_cond_classifier,  classifier_scale, **model_kwargs)
    elif method.lower() == "ipndm":
        return get_sampler_ipndm(sde, eps_fn, num_step, classifier, is_cond_classifier,  classifier_scale, **model_kwargs)
    raise RuntimeError(f"{method} not support!!")

def d_ode_get_sampler(sde, eps_fn, num_step, ts_phase="t", ts_order=2.0, method="teacher_t_ab", ab_order=3, lambdas=None, teacher_targets=None, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    if method.lower() == "teacher_t_ab":
        return get_sampler_teacher_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, classifier, is_cond_classifier,  classifier_scale, **model_kwargs)
    elif method.lower() == "teacher_ipndm":
        return get_sampler_teacher_ipndm(sde, eps_fn, num_step, classifier, is_cond_classifier,  classifier_scale, **model_kwargs)

    elif method.lower() == "distilled_t_ab": 
        return get_sampler_distilled_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, lambdas, teacher_targets, classifier, is_cond_classifier,  classifier_scale, **model_kwargs)
    elif method.lower() == "distilled_ipndm": 
        return get_sampler_distilled_ipndm(sde, eps_fn, num_step, lambdas, teacher_targets, classifier, is_cond_classifier,  classifier_scale, **model_kwargs) 
    raise RuntimeError(f"{method} not support!!")

######################################################################
#tab_DEIS
def get_sampler_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    jax_rev_ts = get_rev_ts(sde, num_step, ts_order, ts_phase=ts_phase)
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:])
    eps_coef = get_ab_eps_coef(sde, ab_order, jax_rev_ts, ab_order)
    jax_ab_coef = jnp.concatenate([x_coef[:, None], eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    def sampler(xT):
        rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
        def ab_body_fn(i, val):
            x, eps_pred = val
            s_t= rev_ts[i]
            
            alpha_t=th_alpha_array[i]

            new_eps = model(x, s_t, alpha_t)
            new_x, new_eps_pred = ab_step(x, ab_coef[i], new_eps, eps_pred)
            return new_x, new_eps_pred

        eps_pred = [xT,] * ab_order
        img, _ = fori_loop(0, num_step, ab_body_fn, (xT, eps_pred))
        return img
    return sampler

def get_sampler_teacher_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    jax_rev_ts = get_rev_ts(sde, num_step, ts_order, ts_phase=ts_phase)
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:])
    eps_coef = get_ab_eps_coef(sde, ab_order, jax_rev_ts, ab_order)
    jax_ab_coef = jnp.concatenate([x_coef[:, None], eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    def sampler(xT):
        rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
        def ab_body_fn(i, val):
            x, eps_pred = val
            s_t= rev_ts[i]
            
            alpha_t=th_alpha_array[i]

            new_eps = model(x, s_t, alpha_t)
            new_x, new_eps_pred = ab_step(x, ab_coef[i], new_eps, eps_pred)
            return new_x, new_eps_pred

        eps_pred = [xT,] * ab_order
        xs, val = teacher_fori_loop(0, num_step, ab_body_fn, (xT, eps_pred))
        return xs, val[0]
    return sampler

def get_sampler_distilled_t_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order, lambdas, teacher_targets, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    jax_rev_ts = get_rev_ts(sde, num_step, ts_order, ts_phase=ts_phase)
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:])
    eps_coef = get_ab_eps_coef(sde, ab_order, jax_rev_ts, ab_order)
    jax_ab_coef = jnp.concatenate([x_coef[:, None], eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    if lambdas == None:
        def sampler(xT):
            rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
            eps_old = None
            def ab_body_fn(i, val):
                x, eps_pred, eps_old = val
                s_t= rev_ts[i]
                
                alpha_t=th_alpha_array[i]
                alpha_next=th_alpha_array[i+1] 

                new_eps = model(x, s_t, alpha_t)
                new_x, new_eps_pred, eps_new= distilled_ab_step(x, ab_coef[i], new_eps, eps_pred, alpha_t, alpha_next)

                x_target = teacher_targets[i+1:i+2] - new_x.unsqueeze(0).to('cpu')
                if eps_old == None:
                    e_prev = (eps_new - x).unsqueeze(0) 
                else:
                    e_prev = (eps_new - eps_old).unsqueeze(0)
                lambdas_i, err = solve_ditillation_objective(e_prev.to('cpu'), x_target)
                #print(f'{i+1}th step error: {err}', flush=True)

                new_x += lambdas_i.to(xT.device)*e_prev.squeeze(0)
                return (new_x, new_eps_pred, eps_new), lambdas_i

            eps_pred = [xT,] * ab_order
            xs, img_eps, lambdas = get_lambdas_fori_loop(0, num_step, ab_body_fn, (xT, eps_pred, eps_old))
            return xs, img_eps[0], lambdas.to(xT.device)
    else:
        def sampler(xT):
            rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
            eps_old = None
            def ab_body_fn(i, val):
                x, eps_pred, eps_old = val
                s_t= rev_ts[i]
                
                alpha_t=th_alpha_array[i]
                alpha_next=th_alpha_array[i+1] 

                new_eps = model(x, s_t, alpha_t)
                new_x, new_eps_pred, eps_new = distilled_ab_step(x, ab_coef[i], new_eps, eps_pred, alpha_t, alpha_next)
                if eps_old == None:
                    new_x += lambdas[i]*(eps_new - x)
                else:
                    new_x += lambdas[i]*(eps_new - eps_old)
                return new_x, new_eps_pred, eps_new

            eps_pred = [xT,] * ab_order
            img, _ , _ = fori_loop(0, num_step, ab_body_fn, (xT, eps_pred, eps_old))
            return img        
    return sampler


######################################################################
#iPNDM

def get_sampler_ipndm(sde, eps_fn, num_step, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    assert isinstance(sde, VPSDE)
    jax_rev_ts = get_rev_ts(sde, num_step, 1, ts_phase="t")
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:]) #(n, )

    def get_linear_ab_coef(i):
        if i == 0:
            return jnp.asarray([1.0, 0, 0, 0]).reshape(-1,4)
        prev_coef = get_linear_ab_coef(i-1)
        cur_coef = None
        if i == 1:
            cur_coef = jnp.asarray([1.5, -0.5, 0, 0])
        elif i == 2:
            cur_coef = jnp.asarray([23, -16, 5, 0]) / 12.0
        else:
            cur_coef = jnp.asarray([55, -59, 37, -9]) / 24.0
        return jnp.concatenate(
            [prev_coef, cur_coef.reshape(-1,4)]
        )
    jax_linear_ab_coef = get_linear_ab_coef(len(jax_rev_ts) - 2) # (n, 4)

    jax_next_ts, jax_cur_ts = jax_rev_ts[1:], jax_rev_ts[:-1]
    jax_next_alpha, jax_cur_alpha = sde.t2alpha_fn(jax_next_ts), sde.t2alpha_fn(jax_cur_ts)
    jax_ddim_coef = jnp.sqrt(1 - jax_next_alpha) - jnp.sqrt(jax_next_alpha / jax_cur_alpha) * jnp.sqrt(1 - jax_cur_alpha) # (n,)

    jax_eps_coef = jax_ddim_coef.reshape(-1,1) * jax_linear_ab_coef

    jax_ab_coef = jnp.concatenate([x_coef[:, None], jax_eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    def sampler(xT):
        rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
        def ab_body_fn(i, val):
            x, eps_pred = val
            s_t= rev_ts[i]
            
            alpha_t=th_alpha_array[i]

            new_eps = model(x, s_t, alpha_t)
            new_x, new_eps_pred = ab_step(x, ab_coef[i], new_eps, eps_pred)
            return new_x, new_eps_pred

        eps_pred = [xT,] * 3
        img, _ = fori_loop(0, num_step, ab_body_fn, (xT, eps_pred))
        return img
    return sampler

def get_sampler_teacher_ipndm(sde, eps_fn, num_step, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    assert isinstance(sde, VPSDE)
    jax_rev_ts = get_rev_ts(sde, num_step, 1, ts_phase="t")
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:]) #(n, )

    def get_linear_ab_coef(i):
        if i == 0:
            return jnp.asarray([1.0, 0, 0, 0]).reshape(-1,4)
        prev_coef = get_linear_ab_coef(i-1)
        cur_coef = None
        if i == 1:
            cur_coef = jnp.asarray([1.5, -0.5, 0, 0])
        elif i == 2:
            cur_coef = jnp.asarray([23, -16, 5, 0]) / 12.0
        else:
            cur_coef = jnp.asarray([55, -59, 37, -9]) / 24.0
        return jnp.concatenate(
            [prev_coef, cur_coef.reshape(-1,4)]
        )
    jax_linear_ab_coef = get_linear_ab_coef(len(jax_rev_ts) - 2) # (n, 4)

    jax_next_ts, jax_cur_ts = jax_rev_ts[1:], jax_rev_ts[:-1]
    jax_next_alpha, jax_cur_alpha = sde.t2alpha_fn(jax_next_ts), sde.t2alpha_fn(jax_cur_ts)
    jax_ddim_coef = jnp.sqrt(1 - jax_next_alpha) - jnp.sqrt(jax_next_alpha / jax_cur_alpha) * jnp.sqrt(1 - jax_cur_alpha) # (n,)

    jax_eps_coef = jax_ddim_coef.reshape(-1,1) * jax_linear_ab_coef

    jax_ab_coef = jnp.concatenate([x_coef[:, None], jax_eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    def sampler(xT):
        rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
        def ab_body_fn(i, val):
            x, eps_pred = val
            s_t= rev_ts[i]
            
            alpha_t=th_alpha_array[i]

            new_eps = model(x, s_t, alpha_t)
            new_x, new_eps_pred = ab_step(x, ab_coef[i], new_eps, eps_pred)
            return new_x, new_eps_pred

        eps_pred = [xT,] * 3
        xs, val = teacher_fori_loop(0, num_step, ab_body_fn, (xT, eps_pred))
        return xs, val[0]
    return sampler

def get_sampler_distilled_ipndm(sde, eps_fn, num_step, lambdas, teacher_targets, classifier=None, is_cond_classifier=False,  classifier_scale=1.0, **model_kwargs):
    def model(x, t_discrete, at):
        with torch.no_grad():
            if is_cond_classifier:
                y = model_kwargs.get("y", None)
                if y is None:
                    raise ValueError("For classifier guidance, the label y has to be in the input.")
                noise_uncond = eps_fn(x, t_discrete, **model_kwargs)
                cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale)
                
                sigma_t = (1 - at).sqrt()
                return noise_uncond - sigma_t * cond_grad
            else:
                return eps_fn(x, t_discrete, **model_kwargs)

    assert isinstance(sde, VPSDE)
    jax_rev_ts = get_rev_ts(sde, num_step, 1, ts_phase="t")
    
    x_coef = sde.psi(jax_rev_ts[:-1], jax_rev_ts[1:]) #(n, )

    def get_linear_ab_coef(i):
        if i == 0:
            return jnp.asarray([1.0, 0, 0, 0]).reshape(-1,4)
        prev_coef = get_linear_ab_coef(i-1)
        cur_coef = None
        if i == 1:
            cur_coef = jnp.asarray([1.5, -0.5, 0, 0])
        elif i == 2:
            cur_coef = jnp.asarray([23, -16, 5, 0]) / 12.0
        else:
            cur_coef = jnp.asarray([55, -59, 37, -9]) / 24.0
        return jnp.concatenate(
            [prev_coef, cur_coef.reshape(-1,4)]
        )
    jax_linear_ab_coef = get_linear_ab_coef(len(jax_rev_ts) - 2) # (n, 4)

    jax_next_ts, jax_cur_ts = jax_rev_ts[1:], jax_rev_ts[:-1]
    jax_next_alpha, jax_cur_alpha = sde.t2alpha_fn(jax_next_ts), sde.t2alpha_fn(jax_cur_ts)
    jax_ddim_coef = jnp.sqrt(1 - jax_next_alpha) - jnp.sqrt(jax_next_alpha / jax_cur_alpha) * jnp.sqrt(1 - jax_cur_alpha) # (n,)

    jax_eps_coef = jax_ddim_coef.reshape(-1,1) * jax_linear_ab_coef

    jax_ab_coef = jnp.concatenate([x_coef[:, None], jax_eps_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)

    jax_alpha_array = sde.t2alpha_fn(jax_rev_ts)
    th_alpha_array = jax2th(jax_alpha_array)

    if lambdas == None:
        def sampler(xT):
            rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
            eps_old = None
            def ab_body_fn(i, val):
                x, eps_pred, eps_old = val
                s_t= rev_ts[i]
                
                alpha_t=th_alpha_array[i]
                alpha_next=th_alpha_array[i+1] 

                new_eps = model(x, s_t, alpha_t)
                new_x, new_eps_pred, eps_new = distilled_ab_step(x, ab_coef[i], new_eps, eps_pred, alpha_t, alpha_next)

                x_target = teacher_targets[i+1:i+2].to(xT.device) - new_x.unsqueeze(0)
                if eps_old == None:
                    e_prev = (eps_new - x).unsqueeze(0) 
                else:
                    e_prev = (eps_new - eps_old).unsqueeze(0)
                lambdas_i, err = solve_ditillation_objective(e_prev, x_target)
                #print(f'{i+1}th step error: {err}', flush=True)

                new_x += lambdas_i*e_prev.squeeze(0)

                return (new_x, new_eps_pred, eps_new), lambdas_i

            eps_pred = [xT,] * 3
            xs, imgs, lambdas = get_lambdas_fori_loop(0, num_step, ab_body_fn, (xT, eps_pred, eps_old))
            return xs, imgs[0], lambdas 
    else:
        def sampler(xT):
            rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
            eps_old = None
            def ab_body_fn(i, val):
                x, eps_pred, eps_old = val
                s_t= rev_ts[i]
                
                alpha_t=th_alpha_array[i]
                alpha_next=th_alpha_array[i+1] 
                
                new_eps = model(x, s_t, alpha_t)
                new_x, new_eps_pred, eps_new = distilled_ab_step(x, ab_coef[i], new_eps, eps_pred, alpha_t, alpha_next)
                if eps_old == None:
                    new_x += lambdas[i]*(eps_new - x)
                else:
                    new_x += lambdas[i]*(eps_new - eps_old)
                return new_x, new_eps_pred, eps_new

            eps_pred = [xT,] * 3
            imgs = fori_loop(0, num_step, ab_body_fn, (xT, eps_pred, eps_old))
            return imgs[0]
    return sampler


######################################################################
#etc

def get_sampler_rho_ab(sde, eps_fn, ts_phase, ts_order, num_step, ab_order):
    jax_rev_ts = get_rev_ts(sde, num_step, ts_order, ts_phase=ts_phase)
    
    highest_order = ab_order
    x_coef = jnp.ones(jax_rev_ts.shape[0]-1)
    jax_rev_rhos = sde.t2rho(jax_rev_ts)
    class HelperSDE(MultiStepSDE):
        def psi(cls, t1, t2):
            return t1 / t1 * t2 / t2
        def eps_integrand(cls, vec_t):
            return vec_t / vec_t
    eps_ab_coef = get_ab_eps_coef(HelperSDE(), highest_order, jax_rev_rhos, ab_order)
    jax_ab_coef = jnp.concatenate([x_coef[:, None], eps_ab_coef], axis=1)
    th_rev_ts, th_ab_coef = jax2th(jax_rev_ts), jax2th(jax_ab_coef)
    nfe = len(th_rev_ts) - 1

    def eps_fn_vrho(v, jax_cur_rho):
        jax_cur_t = sde.rho2t(jax_cur_rho)
        x = sde.v2x(v, jax_cur_t)
        return eps_fn(x, jax2th(jax_cur_t, x))
    
    def sampler(xT):
        rev_ts, ab_coef = th_rev_ts.to(xT.device), th_ab_coef.to(xT.device)
        vT = sde.x2v(xT, jax_rev_ts[0])
        def ab_body_fn(i, val):
            v_cur, eps_prev_preds = val # eps_preds (highest_order, ) start with prev_rho
            rho_cur = jax_rev_rhos[i]
            eps_cur = eps_fn_vrho(v_cur, rho_cur)

            v_next, new_eps_cur_preds = ab_step(v_cur, ab_coef[i], eps_cur, eps_prev_preds)
            return v_next, new_eps_cur_preds

        eps_pred = [xT,] * highest_order
        v_eps, _ = fori_loop(0, nfe, ab_body_fn, (vT, eps_pred))
        x_eps = sde.v2x(v_eps, jax_rev_ts[-1])
        return x_eps
    return sampler

def get_sampler_rho_rk(sde, eps_fn, ts_phase, ts_order, num_step, rk_method):
    jax_rev_ts = get_rev_ts(sde, num_step, ts_order, ts_phase=ts_phase)
    rk_fn = get_rk_fn(rk_method)
    jax_rev_rhos = sde.t2rho(jax_rev_ts)
    th_rev_ts, th_rev_rhos = jax2th(jax_rev_ts), jax2th(jax_rev_rhos)


    def eps_fn_vrho(v, th_rho):
        jax_t = sde.rho2t(th2jax(th_rho))
        x = sde.v2x(v, jax_t)
        return eps_fn(x, jax2th(jax_t, x))

    def _step_fn(i_th, v):
        rho_cur, rho_next = th_rev_rhos[i_th], th_rev_rhos[i_th + 1]
        delta_t = rho_next - rho_cur
        return rk_fn(v, rho_cur, delta_t, eps_fn_vrho)

    def sample_fn(xT):
        vT = sde.x2v(xT, jax_rev_ts[0])
        veps = fori_loop(0, len(th_rev_rhos)-1, _step_fn, vT)
        xeps = sde.v2x(veps, jax_rev_ts[-1])
        return xeps

    return sample_fn
