'''
Contains the ConsistencyBridge class.

The class takes a 'bridge_config' dictionary, which determines the bridge parameterisation:
- base_drift: bool, whether to include the base drift in the controlled SDE
- guiding_type: 'linearised' or 'brownian', the type of guiding drift to use ('linearised' is not yet implemented)
- decay_coeff: bool, whether to multiply the neural network adjustment by the time-decaying coefficient sqrt(T - t)
- sampler: 'euler', 'heun' or 'milstein' (only 'euler' is currently implemented). This is the sampler that is differentiated through when calculating the training targets.

This works for a general sigma function, which can be a scalar, diagonal, or full matrix, and can depend on (x,t). The code automatically constructs the necessary functions to avoid excessive matrix multiplication when it is not needed.

Also supports STL adjustments when calculating the training targets. These are enabled by setting 'STL_adjustments': True in the train_config dictionary.
'''

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import flax
import flax.linen as nn
from functools import partial
from typing import Any, Callable
from tqdm import trange
import wandb

from samplers import euler_maruyama_sampler

# TrainState dataclass
@flax.struct.dataclass
class TrainState:
    """Immutable training state threaded through the jitted training loops (updated via `.replace`)."""
    # Number of optimiser steps taken so far; incremented once per step by _train_step_fn.
    step: int
    # Current model parameters.
    params: Any
    # Exponential moving average of `params`, updated after every optimiser step.
    ema_params: Any
    # Optax optimiser state for `params`.
    opt_state: Any
    # Exponential moving average of the gradients (kept for debugging / monitoring only).
    ema_grads: Any
    # Learning rate recorded at initialisation; the training step in this file does not read it.
    lr: float
    # Time discretisation grid (length num_steps + 1) used when sampling trajectories.
    ts: Any
    # Step size passed to optax.incremental_update for ema_params (weight on the newest params).
    ema_rate: float = 0.01
    # Step size passed to optax.incremental_update for ema_grads (weight on the newest gradients).
    grad_ema_rate: float = 0.01


# ==============================================================================
# Define required functions as pure functions for jitting
# ==============================================================================

def _controlled_drift_fn(params, x, t, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn):
    """
    Returns the controlled drift. This can optionally include a base drift term, and a 'decay' multiplier on the neural adjustment.
    """
    base_drift_val = optional_base_drift_fn(x, t)
    guiding_drift_val = guiding_drift_fn(x, t)
    _, apply_sigma_fn = sigma_fn(x,t)
    model_output = model.apply(params, x, t)
    neural_adjustment = coeff_fn(t) * apply_sigma_fn(model_output)
    return base_drift_val + guiding_drift_val + neural_adjustment

def _sample_sde_fn(key, drift_fn, sigma_fn, sampler_fn, x_0, shape, ts):
    """
    Samples the SDE for the given drift, using the given sampler function and time discretisation.
    """
    def step(carry, t):
        x, key, prev_t = carry
        key, subkey = jax.random.split(key)
        dt = t - prev_t
        dB_t = jax.random.normal(subkey, shape=shape) * jnp.sqrt(dt)
        x_next, drift = sampler_fn(x, prev_t, dB_t, dt, drift_fn, sigma_fn)
        output = (x_next, drift, dB_t)
        return (x_next, key, t), output

    carry = (x_0, key, ts[0])
    _, outputs = jax.lax.scan(step, carry, ts[1:])
    traj, drift_traj, dBt_traj = outputs
    
    traj = jnp.vstack([x_0[None, :], traj])
    drift_traj = jnp.vstack([drift_traj, drift_traj[-1][None, :]])   # repeat the final one, to make a consistent size. It is not used in the computation.
    dBt_traj = jnp.vstack([dBt_traj, dBt_traj[-1][None, :]])

    return traj, drift_traj, dBt_traj

def _sample_controlled_sde_fn(key, params, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, sampler_fn, x_0, shape, ts):
    """
    Sample one trajectory of the controlled SDE.

    The controlled drift (see _controlled_drift_fn) is closed over `params` and the bridge
    components, then simulated with _sample_sde_fn on the grid `ts`.

    Returns:
        (traj, drift_traj, dBt_traj) as produced by _sample_sde_fn.
    """
    def controlled_drift(x, t):
        return _controlled_drift_fn(params, x, t, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn)

    return _sample_sde_fn(key, controlled_drift, sigma_fn, sampler_fn, x_0, shape, ts)


def _calculate_targets_fn(x_traj, ut_traj, dBt_traj, base_drift_fn, sigma_fn, sampler_fn, control_fn, B_ratio_fn, shape, train_config, ts):
    """
    Computes the training targets by solving the backward equation. The B_ratio function determines the weighting over the timepoints (the beta schedule).
    """
    F1 = jnp.zeros(shape)
    ts = jnp.flip(ts, axis=0)

    partial_sampler_fn = partial(sampler_fn, drift_fn=base_drift_fn, sigma_fn=sigma_fn)

    def step(carry, inputs):
        F_t_plus_dt, t_plus_dt, x_t_plus_dt, u_t_plus_dt, dB_t_plus_t = carry
        t, x, u, dB_t = inputs
        dt = t_plus_dt - t

        if train_config.get('jacobian_method', 'euler') == 'euler':
            _, vjp_fun = jax.vjp(lambda x_arg: partial_sampler_fn(x_arg, t, dB_t, dt)[0], x)
            def right_multiply_jacobian(v):
                return vjp_fun(v)[0]

        if train_config['jacobian_method'] == 'exp':
            left_jac = jax.jacfwd(lambda x_arg: base_drift_fn(x_arg, t))(x)
            right_jac = jax.jacfwd(lambda x_arg: base_drift_fn(x_arg, t_plus_dt))(x_t_plus_dt)
            exp_jac = jax.scipy.linalg.expm(dt * left_jac)
            exp_jac = jax.scipy.linalg.expm(0.5 * dt * (left_jac + right_jac)) # trapezoidal
            def right_multiply_jacobian(v):
                return v @ exp_jac
        
        B_ratio = B_ratio_fn(x, t, x_t_plus_dt, t_plus_dt)
        F_t = right_multiply_jacobian(
            B_ratio* F_t_plus_dt + (1 - B_ratio) * u_t_plus_dt
        )

        if train_config.get('STL_adjustments', False) is True:
            # Applies the STL adjustments

            # (∇u) (sigma dB_t)
            def u_fun(x_arg):
                return control_fn(xs=x_arg[None], ts=t[None])[0]  # (d,)

            sigma_val, apply_sigma_fn = sigma_fn(x, t)
            noise_term = apply_sigma_fn(dB_t)  # = sigma(x,t) @ dB_t, shape (d,)
            _, STL_adjustment_first_term = jax.jvp(u_fun, (x,), (noise_term,))

            # (u) (∇sigma dB_t)
            def _sigma_times_dB(x_):
                _, apply_sigma_fn_ = sigma_fn(x_, t)
                return apply_sigma_fn_(dB_t)

            _, grad_sigma_vjp_fun = jax.vjp(_sigma_times_dB, x)
            STL_adjustment_second_term = grad_sigma_vjp_fun(u)[0]

            STL_adjustment = STL_adjustment_first_term + STL_adjustment_second_term
            F_t = F_t - STL_adjustment

        ratio = -1 # dummy value

        return (F_t, t, x, u, dB_t), (F_t, ratio)

    rev_uts = jnp.flip(ut_traj, axis=0)
    rev_xts = jnp.flip(x_traj, axis=0)
    rev_dB_ts = jnp.flip(dBt_traj, axis=0)

    terminal_ut = rev_uts[1]
    carry = (terminal_ut, ts[1], rev_xts[1], rev_uts[1], rev_dB_ts[1]) # the first ones aren't used (corresponding to endpoint)
    inputs = (ts[2:], rev_xts[2:], rev_uts[2:], rev_dB_ts[2:])
    _, (F_ts, ratios) = jax.lax.scan(step, carry, inputs)

    F_ts = jnp.vstack([F1[None, :], terminal_ut[None,:], F_ts])
    F_ts = jnp.flip(F_ts, axis=0)

    ratios = jnp.flip(ratios, axis=0)

    return F_ts, ratios


def _control_fn(params, xs, ts, base_drift_fn, model, optional_base_drift_fn, guiding_drift, coeff_fn, sigma_fn, a_inv_fn):
    """
    Compute the control u for a batch of states and times.

    The base drift is evaluated at every (x, t) pair, and the result is handed to
    _get_ut_from_params_fn, which converts the controlled-drift residual into a control.
    """
    base_drifts = jax.vmap(base_drift_fn)(xs, ts)
    return _get_ut_from_params_fn(
        params, xs, base_drifts, ts,
        model, optional_base_drift_fn, guiding_drift, coeff_fn, sigma_fn, a_inv_fn,
    )


def _get_ut_from_params_fn(params, xs, bts, ts, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, a_inv_fn):
    """
    Compute the control u for a batch of states and times, given precomputed base drifts.

    The control is u = (sigma sigma^T)^{-1} (controlled_drift - base_drift), evaluated row-wise.

    Args:
        params: model parameters (shared across the batch).
        xs, ts: batched states and times (leading axis is the batch).
        bts: base drift values at (xs, ts), batched like xs.
        remaining args: bridge components forwarded to _controlled_drift_fn / a_inv_fn.

    Returns:
        Controls, one row per batch element.
    """
    def drift_at(p, x, t):
        return _controlled_drift_fn(p, x, t, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn)

    batched_drifts = jax.vmap(drift_at, in_axes=(None, 0, 0))(params, xs, ts)
    drift_gap = batched_drifts - bts
    return jax.vmap(a_inv_fn)(xs, ts, drift_gap)

def _loss_fn(params, batch, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, a_inv_fn, train_config):
    """
    Mean squared-error loss between predicted controls and training targets.

    Each sample's squared error (summed over the last axis) is capped at
    train_config['loss_clip'] (default 1e6) before averaging over the batch.

    Args:
        params: model parameters being differentiated.
        batch: (xs, bts, targets, ts) tuple of batched arrays.

    Returns:
        Scalar loss.
    """
    states, base_drifts, target_controls, times = batch
    predicted = _get_ut_from_params_fn(
        params, states, base_drifts, times, model,
        optional_base_drift_fn=optional_base_drift_fn,
        guiding_drift_fn=guiding_drift_fn,
        coeff_fn=coeff_fn,
        sigma_fn=sigma_fn,
        a_inv_fn=a_inv_fn,
    )

    errors = jnp.sum((predicted - target_controls) ** 2, axis=-1)
    capped = jnp.minimum(errors, train_config.get('loss_clip', 1e6))
    return jnp.mean(capped)

def _train_step_fn(state, batch, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, a_inv_fn, optimizer, train_config):
    """
    Take one optimiser step on `batch`.

    Updates the parameters and their EMA. Also maintains an EMA of the gradients for debugging,
    seeded with the raw gradients on the very first step.

    Returns:
        (new_state, loss).
    """
    loss, grads = jax.value_and_grad(_loss_fn)(
        state.params, batch, model, optional_base_drift_fn, guiding_drift_fn,
        coeff_fn, sigma_fn, a_inv_fn, train_config,
    )

    updates, opt_state = optimizer.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    ema_params = optax.incremental_update(params, state.ema_params, state.ema_rate)

    def _seed_ema(_):
        return grads

    def _blend_ema(_):
        return optax.incremental_update(grads, state.ema_grads, state.grad_ema_rate)

    ema_grads = jax.lax.cond(state.step == 0, _seed_ema, _blend_ema, None)

    return state.replace(
        step=state.step + 1,
        params=params,
        ema_params=ema_params,
        ema_grads=ema_grads,
        opt_state=opt_state,
    ), loss



# ==============================================================================
# Loops that can be scanned over
# ==============================================================================

def _inner_loop_body(carry, _, train_config, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, a_inv_fn, optimizer):
    """
    One inner-loop iteration, used as a jax.lax.scan body.

    Draws `train_batch_size` random (trajectory, time-index) pairs, with replacement, from the
    fixed trajectory data. Time indices are taken from [0, num_steps), so the endpoint is never
    used. One training step is then taken on that minibatch.

    Returns:
        ((new_state, traj_data, next_key), loss).
    """
    state, traj_data, key = carry
    x_traj, bt_vals, targets, ts = traj_data
    n_samples = train_config['train_batch_size']

    step_key, next_key = jax.random.split(key)
    time_key, traj_key = jax.random.split(step_key)

    t_pick = jax.random.choice(time_key, jnp.arange(train_config['num_steps']), shape=(n_samples,), replace=True)
    b_pick = jax.random.choice(traj_key, jnp.arange(train_config['traj_batch_size']), shape=(n_samples,), replace=True)

    x_batch = x_traj[b_pick, t_pick]
    bt_batch = bt_vals[b_pick, t_pick]
    target_batch = targets[b_pick, t_pick]
    t_batch = ts[t_pick]

    new_state, loss = _train_step_fn(
        state, (x_batch, bt_batch, target_batch, t_batch), model,
        optional_base_drift_fn=optional_base_drift_fn, guiding_drift_fn=guiding_drift_fn,
        coeff_fn=coeff_fn, sigma_fn=sigma_fn, a_inv_fn=a_inv_fn,
        optimizer=optimizer, train_config=train_config,
    )
    return (new_state, traj_data, next_key), loss

def _coeff_schedule(outer_step, train_config):
    """Return the multiplier on the base drift for this outer iteration (1.0 unless annealed)."""
    schedule_type = train_config.get('coeff_schedule', None)
    if schedule_type == 'linear':
        # Linearly increase from near 0 to 1.0 over all outer iterations
        total_steps = train_config['num_outer_iterations']
        return jnp.minimum(1.0, (outer_step + 1) / total_steps)
    if schedule_type == 'sine':
        # Quarter sine wave from near 0 to 1.0 over all outer iterations
        total_steps = train_config['num_outer_iterations']
        return jnp.sin(0.5*jnp.pi * (outer_step + 1) / total_steps)
    if schedule_type == 'linear_half':
        # Linear increase from near 0 to 1.0 over the first half, then stay at 1.0.
        # max(..., 1) avoids a division by zero when there is only one outer iteration.
        half_steps = max(train_config['num_outer_iterations'] // 2, 1)
        return jnp.where(outer_step < half_steps, (outer_step + 1) / half_steps, 1.0)
    # No (or unrecognised) schedule: constant coefficient
    return 1.0


def _sigma_scale_schedule(outer_step, train_config):
    """
    Return the multiplier on sigma for this outer iteration, or None when noise annealing is off.

    Uses a 'linear half' schedule. The scale decays linearly from train_config['sigma_max_scale']
    to 1.0 over the first half of the outer iterations, then stays at 1.0, so that training finishes
    on the base diffusion used at sampling time.
    """
    sigma_max_scale = train_config.get('sigma_max_scale', None)
    if sigma_max_scale is None:
        return None
    half_steps = max(train_config['num_outer_iterations'] // 2, 1)
    progress = (outer_step + 1) / half_steps
    return jnp.where(
        outer_step < half_steps,
        sigma_max_scale + (1.0 - sigma_max_scale) * progress,
        1.0,
    )


def _make_B_ratio_fn(train_config, T):
    """
    Return B_ratio_fn(x_t, t, x_t_plus_dt, t_plus_dt) for train_config['beta_schedule'] (default 'average').

    Raises:
        ValueError: for an unrecognised schedule name.
    """
    beta_schedule = train_config.get('beta_schedule', 'average')
    if beta_schedule == 'average':
        return lambda x_t, t, x_t_plus_dt, t_plus_dt: (T - t_plus_dt) / (T - t)
    if beta_schedule == 'endpoint':
        return lambda x_t, t, x_t_plus_dt, t_plus_dt: 1.0
    if beta_schedule == 'next_step':
        return lambda x_t, t, x_t_plus_dt, t_plus_dt: 0.0
    if beta_schedule == 'geom':
        B_ratio = train_config['B_ratio']
        return lambda x_t, t, x_t_plus_dt, t_plus_dt: B_ratio
    if beta_schedule == 'sqrt':
        # Note: despite the name, this raises the 'average' ratio to the power 1.5.
        return lambda x_t, t, x_t_plus_dt, t_plus_dt: jnp.power((T - t_plus_dt) / (T - t + 1e-6), 1.5)
    raise ValueError(
        f"Unknown beta_schedule {beta_schedule!r}; expected one of "
        "'average', 'endpoint', 'next_step', 'geom', 'sqrt'."
    )


def _outer_loop_body(state, key, outer_step, train_config, static_bridge_data, optimizer):
    """
    Body of the outer training loop.

    One outer iteration:
      1. builds the (optionally annealed) drift and diffusion functions for this iteration,
      2. samples `traj_batch_size` trajectories of the controlled SDE using the EMA parameters,
      3. recovers the controls along each trajectory, and overwrites the last two with the control
         that reaches x_T exactly in the final discretisation step,
      4. computes the training targets by solving the backward equation,
      5. runs `num_inner_iterations` optimiser steps on minibatches drawn from those trajectories.

    Args:
        state: TrainState.
        key: PRNG key for this iteration.
        outer_step: index of the outer iteration, used by the annealing schedules.
        train_config: training configuration dict.
        static_bridge_data: dict of bridge components assembled by ConsistencyBridge.train.
        optimizer: optax optimizer.

    Returns:
        (final_state, outputs), where outputs has keys 'ema_params', 'mean_loss' and 'x_traj'.

    Raises:
        ValueError: for an unrecognised 'beta_schedule'.
    """
    # =======================================================
    # Set up the required bridge objects
    # =======================================================
    model = static_bridge_data['model']
    x_0, x_T = static_bridge_data['x_0'], static_bridge_data['x_T']
    T = static_bridge_data['T']
    shape = static_bridge_data['shape']
    ts = state.ts  # time discretisation carried on the train state

    # Optionally anneal the drift coefficient (constant 1.0 when no schedule is configured)
    coeff = _coeff_schedule(outer_step, train_config)

    # Create a scaled version of the base drift function for this iteration
    unscaled_base_drift_fn = static_bridge_data['base_drift_fn']
    base_drift_fn = lambda x, t: coeff * unscaled_base_drift_fn(x, t)

    # construct objects used in the bridge parameterisation
    unscaled_optional_base_drift_fn = static_bridge_data['optional_base_drift_fn']
    optional_base_drift_fn = lambda x, t: coeff * unscaled_optional_base_drift_fn(x, t)
    guiding_drift_fn = static_bridge_data['guiding_drift_fn']
    coeff_fn = static_bridge_data['coeff_fn']
    base_sigma_fn = static_bridge_data['sigma_fn']
    base_a_inv_fn = static_bridge_data['a_inv_fn']
    sampler_fn = static_bridge_data['sampler_fn']

    # Optionally anneal the noise level (decays towards the base sigma)
    sigma_scale = _sigma_scale_schedule(outer_step, train_config)
    if sigma_scale is not None:
        def sigma_fn(x, t):
            sigma_val, apply_sigma_fn = base_sigma_fn(x, t)
            return sigma_scale * sigma_val, lambda dB_t: sigma_scale * apply_sigma_fn(dB_t)

        def a_inv_fn(x, t, v):
            # (s sigma)(s sigma)^T = s^2 sigma sigma^T, so the inverse picks up a 1/s^2 factor
            return base_a_inv_fn(x, t, v) / (sigma_scale ** 2)
    else:
        sigma_fn = base_sigma_fn
        a_inv_fn = base_a_inv_fn

    # ====================================================
    # Sample the trajectories, compute the controls and correct terminal control
    # ====================================================
    sample_key, inner_loop_key = jax.random.split(key)
    sample_keys = jax.random.split(sample_key, train_config['traj_batch_size'])
    vmap_sample_fn = jax.vmap(
        _sample_controlled_sde_fn, in_axes=(0, None, None, None, None, None, None, None, None, None, None)
    )
    x_traj, drift_traj, dBt_traj = vmap_sample_fn(
        sample_keys, state.ema_params, model, optional_base_drift_fn, guiding_drift_fn, coeff_fn, sigma_fn, sampler_fn, x_0, shape, ts
    )

    # calculate the controls: outer vmap over time (axis 1 of x_traj), inner vmap over the batch
    vmap_base_drift = jax.vmap(jax.vmap(base_drift_fn, in_axes=(0, None)), in_axes=(1, 0))
    bt_vals = vmap_base_drift(x_traj, ts).transpose(1, 0, 2)    # (B, num_steps+1, d)
    residuals = drift_traj - bt_vals  # (B, num_steps+1, d)
    vmap_apply_inv_inner = jax.vmap(a_inv_fn, in_axes=(0, None, 0))   # (B,d) x scalar t -> (B,d)
    vmap_apply_inv = jax.vmap(vmap_apply_inv_inner, in_axes=(1, 0, 1))   # over time, feeds scalar t
    ut_traj = vmap_apply_inv(x_traj, ts, residuals).transpose(1, 0, 2)   # (B, num_steps+1, d)

    # set terminal conditions - the final discretisation step is enforced exactly
    terminal_Brownian_drift = (x_T - x_traj[:, -2, :]) / (ts[-1] - ts[-2] + 1e-12)  # (B, d)
    terminal_residuals = terminal_Brownian_drift - bt_vals[:, -2, :]
    terminal_uts = jax.vmap(a_inv_fn, in_axes=(0, None, 0))(
        x_traj[:, -2, :], ts[-2], terminal_residuals
    )  # (B, d)
    # set the final two controls to the terminal control, only the penultimate is used for training
    ut_traj = ut_traj.at[:, -1, :].set(terminal_uts)
    ut_traj = ut_traj.at[:, -2, :].set(terminal_uts)

    # =========================================================
    # Compute the training targets by solving the backward equation
    # =========================================================
    B_ratio_fn = _make_B_ratio_fn(train_config, T)

    control_fn = partial(_control_fn, params=state.ema_params, model=model, base_drift_fn=base_drift_fn, optional_base_drift_fn=optional_base_drift_fn, guiding_drift=guiding_drift_fn, coeff_fn=coeff_fn, sigma_fn=sigma_fn, a_inv_fn=a_inv_fn)

    vmap_targets_fn = jax.vmap(_calculate_targets_fn, in_axes=(0, 0, 0, None, None, None, None, None, None, None, None))
    training_targets, ratios = vmap_targets_fn(x_traj, ut_traj, dBt_traj, base_drift_fn, sigma_fn, sampler_fn, control_fn, B_ratio_fn, shape, train_config, ts)

    traj_data = (x_traj, bt_vals, training_targets, ts)

    # =========================================================
    # Run inner training loops
    # =========================================================
    initial_inner_carry = (state, traj_data, inner_loop_key)

    partial_inner_loop_body = partial(
        _inner_loop_body, train_config=train_config, model=model, optional_base_drift_fn=optional_base_drift_fn, guiding_drift_fn=guiding_drift_fn, coeff_fn=coeff_fn, sigma_fn=sigma_fn, a_inv_fn=a_inv_fn, optimizer=optimizer
    )

    final_inner_carry, losses = jax.lax.scan(
        partial_inner_loop_body,
        init=initial_inner_carry,
        xs=None,
        length=train_config['num_inner_iterations']
    )

    final_state, _, _ = final_inner_carry

    outputs = {
        'ema_params': final_state.ema_params,
        'mean_loss': jnp.mean(losses),
        'x_traj': x_traj,
    }

    return final_state, outputs



def _make_sigma_fn(sigma, dim):
    """
    Returns a callable sigma_fn(x,t) -> (apply_sigma, sigma_val)
    where:
      - apply_sigma(dB_t) computes sigma @ dB_t efficiently
    """
    # --- normalize sigma to something callable
    if callable(sigma):
        def base_sigma_val(x,t):
            return sigma(x,t)
    else:
        def base_sigma_val(x,t):
            return sigma

    def sigma_fn(x, t):
        val = jnp.asarray(base_sigma_val(x,t))
        if val.ndim == 0:
            # scalar case: multiply directly
            s = val
            def apply_sigma(dB_t):
                return s * dB_t
            return s, apply_sigma

        elif val.ndim == 1:
            # diagonal case
            if val.shape[0] != dim:
                raise ValueError(f"sigma vector has shape {val.shape}, expected ({dim},)")
            s = val
            def apply_sigma(dB_t):
                return s * dB_t  # elementwise
            return s, apply_sigma

        elif val.ndim == 2:
            # full matrix
            if val.shape != (dim, dim):
                raise ValueError(f"sigma matrix has shape {val.shape}, expected ({dim}, {dim})")
            s = val
            def apply_sigma(dB_t):
                return s @ dB_t
            return s, apply_sigma

        else:
            raise ValueError("sigma must be scalar, vector (len d), or (d,d) matrix.")
        
    def a_inv_fn(x, t, v):
        # applies (sigma sigmaᵀ)^{-1} @ v. Handles scalar, diagonal, and full matrix cases.
        val = jnp.asarray(base_sigma_val(x, t))
        if val.ndim == 0:
            return v / (val**2 + 1e-12)
        elif val.ndim == 1:
            return v / (val**2 + 1e-12)
        else:
            a_matrix = val @ val.T
            eps = 1e-8
            return jnp.linalg.solve(a_matrix + eps*jnp.eye(a_matrix.shape[0]), v)
    
    return sigma_fn, a_inv_fn
    

# ==============================================================================
# Consistency Bridge Class
# ==============================================================================

class ConsistencyBridge:
    """
    Trains a neural correction to the drift of an SDE so that trajectories started at x_0 are
    guided towards x_T at time T.

    The controlled drift is
        optional base drift + guiding drift + coeff(t) * sigma(x, t) @ model(x, t).
    The behaviour of each term is selected by `bridge_config` (see __init__).
    """

    def __init__(self, shape, x_0, x_T, base_drift_fn, sigma_fn, model, bridge_config=None, T=1.0):
        """
        Args:
            shape: shape of a single state; shape[0] is used as the state dimension.
            x_0: initial state.
            x_T: target terminal state.
            base_drift_fn: (x, t) -> base drift of the SDE.
            sigma_fn: diffusion coefficient (constant or callable (x, t)); scalar, length-d
                vector, or (d, d) matrix.
            model: object with `init(key, x, t)` and `apply(params, x, t)`.
            bridge_config: optional dict with keys 'base_drift' (default False), 'guiding_type'
                (default 'brownian'), 'decay_coeff' (default True) and 'sampler' (default 'euler').
            T: terminal time.

        Raises:
            NotImplementedError: for the 'linearised' guiding type, or the 'heun' / 'milstein' samplers.
            ValueError: for an unrecognised guiding type or sampler.
        """
        self.shape = shape
        self.x_0 = x_0
        self.x_T = x_T
        self.base_drift_fn = base_drift_fn
        self.model = model
        self.T = T

        # defaults
        defaults = {
            'base_drift': False,
            'guiding_type': 'brownian',
            'decay_coeff': True,
            'sampler': 'euler',
        }

        # merge defaults with any user-provided values
        if bridge_config is None:
            bridge_config = {}
        self.bridge_config = {**defaults, **bridge_config}
        # Read from the merged config so that the defaults actually apply. Reading the raw dict
        # raised KeyError for omitted keys, including when bridge_config was None.
        config = self.bridge_config

        # Optionally include the base drift in the controlled SDE
        if config['base_drift']:
            self.optional_base_drift_fn = base_drift_fn
        else:
            self.optional_base_drift_fn = lambda x, t: jnp.zeros_like(x)

        # Set up the guiding drift function based on the configuration
        guiding_type = config['guiding_type']
        if guiding_type == 'linearised':
            raise NotImplementedError("The 'linearised' guiding type is not implemented in this version.")
        elif guiding_type == 'brownian':
            def Brownian_bridge_drift_fn(x, t):
                # floor on (T - t) keeps the drift finite at the terminal time
                return (x_T - x) / jnp.maximum(T - t, 1e-3)
            self.guiding_drift = Brownian_bridge_drift_fn
        else:
            raise ValueError(f"Unknown guiding_type {guiding_type!r}; expected 'brownian' or 'linearised'.")

        if config['decay_coeff']:
            self.coeff_fn = lambda t: jnp.sqrt(T - t)
        else:
            self.coeff_fn = lambda t: 1.0

        sampler = config['sampler']
        if sampler == 'euler':
            self.sampler_fn = euler_maruyama_sampler
        elif sampler in ('heun', 'milstein'):
            raise NotImplementedError(f"The {sampler!r} sampler is not implemented in this version.")
        else:
            raise ValueError(f"Unknown sampler {sampler!r}; expected 'euler'.")

        # sigma_fn(x, t) -> (sigma_val, apply_sigma); a_inv(x, t, v) -> (sigma sigma^T)^{-1} v
        self.sigma_fn, self.a_inv = _make_sigma_fn(sigma_fn, dim=shape[0])

    def train(self, key, train_config, wandb_config=None, pretrained_params=None):
        """
        Train the control network.

        Args:
            key: PRNG key.
            train_config: training configuration dict. Keys read directly here are 'lr',
                'ema_rate', 'num_steps' and 'num_outer_iterations'. Optional keys read here are
                'grad_clip', 'adam_b1', 'adam_b2', 'grad_ema_rate' and 'ckpt_freq'. Further keys
                are read by the jitted training loop.
            wandb_config: optional dict. If given, a wandb run is started, and the loss, EMA gradient
                norm and any plots from 'plot_fn_lst' are logged every 'log_interval' outer iterations.
            pretrained_params: optional initial parameters; otherwise `model.init` is used.

        Returns:
            (state, ema_params_list, ema_grad_norms). These are the final TrainState; the EMA
            parameters at initialisation and every 'ckpt_freq' outer iterations; and the norm of
            the gradient EMA after each outer iteration.
        """
        # optionally initialize wandb
        if wandb_config is not None:
            wandb.init(
                project=wandb_config.get("project", "consistency_bridges"),
                name=wandb_config.get("name", None),
                config={**train_config, **wandb_config},  # log configs
            )

        # 1. Initialize optimizer and state
        key, init_key = jax.random.split(key)

        if pretrained_params is not None:
            init_params = pretrained_params
        else:
            init_params = self.model.init(init_key, jnp.zeros(self.shape), 0.0)

        optimizer = optax.chain(
            optax.clip_by_global_norm(train_config.get('grad_clip', 1.0)),
            optax.adam(learning_rate=train_config['lr'],
                       b1=train_config.get('adam_b1', 0.9),
                       b2=train_config.get('adam_b2', 0.999),
            )
        )

        initial_state = TrainState(
            step=0,
            params=init_params,
            ema_params=init_params,
            ema_grads=jax.tree_util.tree_map(jnp.zeros_like, init_params),
            opt_state=optimizer.init(init_params),
            lr=train_config['lr'],
            ema_rate=train_config['ema_rate'],
            grad_ema_rate=train_config.get('grad_ema_rate', 1.0),
            ts=jnp.linspace(0, self.T, train_config['num_steps'] + 1)  # initialise at uniform time discretisation
        )

        # 2. Bundle all static data for the pure functions
        static_bridge_data = {
            'model': self.model,
            'x_0': self.x_0,
            'x_T': self.x_T,
            'T': self.T,
            'shape': self.shape,
            'base_drift_fn': self.base_drift_fn,
            'optional_base_drift_fn': self.optional_base_drift_fn,
            'guiding_drift_fn': self.guiding_drift,
            'coeff_fn': self.coeff_fn,
            'sigma_fn': self.sigma_fn,
            'a_inv_fn': self.a_inv,
            'sampler_fn': self.sampler_fn,
        }

        # 3. JIT compile the outer body (but not scan over it)
        outer_step_fn = jax.jit(
            partial(_outer_loop_body, train_config=train_config, static_bridge_data=static_bridge_data, optimizer=optimizer)
        )

        # 4. Manual outer loop with tqdm
        state = initial_state
        ema_params_list = [state.ema_params]
        mean_losses = []
        ema_grad_norms = []

        outer_loop_keys = jax.random.split(key, train_config['num_outer_iterations'])

        with trange(train_config['num_outer_iterations'], desc="Training", unit="step") as pbar:
            for step in pbar:
                state, outputs = outer_step_fn(state, outer_loop_keys[step], step)
                mean_losses.append(outputs['mean_loss'])

                ema_grad_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(state.ema_grads)))
                ema_grad_norms.append(ema_grad_norm)

                pbar.set_postfix(loss=outputs['mean_loss'], ema_grad_norm=ema_grad_norm)

                # save ema params at intervals
                if (step + 1) % train_config.get('ckpt_freq', 10) == 0:
                    ema_params_list.append(state.ema_params)

                # optional wandb logging
                if wandb_config is not None:
                    if step % wandb_config.get('log_interval', 10) == 0:

                        wandb.log({
                            'train/loss': float(outputs['mean_loss']),
                            'train/ema_grad_norm': float(ema_grad_norm),
                        }, step=step)

                        for plot_fn in wandb_config.get('plot_fn_lst', []):
                            fig = plot_fn(outputs['x_traj'])   # plot_fn should return a matplotlib figure or None
                            if fig is not None:
                                wandb.log({f"plot/{plot_fn.__name__}": wandb.Image(fig)}, step=step)
                                plt.close(fig)

        print("Training finished.")
        return state, ema_params_list, ema_grad_norms

    @partial(jax.jit, static_argnums=(0, 2, 3))
    def sample_sde(self, key, drift_fn, num_steps):
        """Sample one trajectory with an arbitrary drift on a uniform grid of `num_steps` steps over [0, T]."""
        ts = jnp.linspace(0, self.T, num_steps + 1)
        return _sample_sde_fn(key, drift_fn, self.sigma_fn, self.sampler_fn, self.x_0, self.shape, ts)

    @partial(jax.jit, static_argnums=(0, 3))
    def sample_controlled_sde(self, key, params, num_steps):
        """Sample one trajectory of the controlled SDE on a uniform grid of `num_steps` steps over [0, T]."""
        ts = jnp.linspace(0, self.T, num_steps + 1)
        return _sample_controlled_sde_fn(key, params, self.model, self.optional_base_drift_fn, self.guiding_drift, self.coeff_fn, self.sigma_fn, self.sampler_fn, self.x_0, self.shape, ts)

    def controlled_drift(self, params, x, t):
        """Evaluate the controlled drift at a single (x, t)."""
        return _controlled_drift_fn(params, x, t, self.model, self.optional_base_drift_fn, self.guiding_drift, self.coeff_fn, self.sigma_fn)

    def control_fn(self, params, xs, ts):
        """Compute the control u for a batch of states `xs` at times `ts`."""
        # The inverse-diffusion helper is stored as `self.a_inv`; `self.a_inv_fn` never existed.
        return _control_fn(params, xs, ts, self.base_drift_fn, self.model, self.optional_base_drift_fn, self.guiding_drift, self.coeff_fn, self.sigma_fn, self.a_inv)