from flax import nnx
from jax import Array, jit
from jax.random import fold_in, normal, randint

from offline.diffusion.modules.conditional import conditional_diffusion_loss_fn
from offline.diffusion.modules import DiffusionPolicy
from offline.modules.base import TrainState


@jit
def train_step(
    actions: Array,
    graphdef: nnx.GraphDef[TrainState[DiffusionPolicy]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
    step: int,
    train_noise_key: Array,
    train_time_key: Array,
):
    """Run one jit-compiled optimisation step of the diffusion noise predictor.

    The stateful ``TrainState`` is rebuilt from its split ``(graphdef, graphstate)``
    form, updated in place, and split again, so only pytrees cross the ``jit``
    boundary.

    Args:
        actions: Clean samples to be noised. One diffusion timestep is drawn per
            index of the leading axes (timesteps have shape
            ``actions.shape[:-1] + (1,)``). The last axis is presumably the action
            dimension (TODO: confirm).
        graphdef: Static graph definition of the ``TrainState``.
        graphstate: State half of the split ``TrainState``.
        observations: Conditioning inputs, forwarded to the loss as ``conditions``.
        step: Training step counter. It is folded into both base PRNG keys so each
            step draws fresh noise and timesteps.
        train_noise_key: Base PRNG key for the Gaussian noise.
        train_time_key: Base PRNG key for the diffusion timesteps.

    Returns:
        A ``(graphstate, metrics)`` pair:

        - ``graphstate``: the updated state half of the ``TrainState``.
        - ``metrics``: a dict holding the scalar loss under ``"loss/diffusion"``.
    """
    train_state = nnx.merge(graphdef, graphstate)
    diffusion = train_state.model.diffusion

    # Derive per-step keys from the fixed base keys.
    time_key = fold_in(train_time_key, step)
    noise_key = fold_in(train_noise_key, step)

    # Integer timesteps in [0, diffusion_steps). randint's maxval is exclusive.
    # The trailing singleton axis lines up with the last axis of ``actions``.
    timesteps = randint(
        time_key,
        (*actions.shape[:-1], 1),
        minval=0,
        maxval=diffusion.diffusion_steps,
    )
    noise = normal(noise_key, actions.shape)
    noisy_actions = diffusion.add_noise(
        noise=noise, samples=actions, timesteps=timesteps
    )

    # nnx.value_and_grad differentiates w.r.t. the first positional argument by
    # default, so gradients cover only the noise predictor. The remaining inputs
    # are passed by keyword and treated as constants.
    loss, grads = nnx.value_and_grad(conditional_diffusion_loss_fn)(
        diffusion.noise_predictor,
        conditions=observations,
        noise=noise,
        noisy_samples=noisy_actions,
        timesteps=timesteps,
    )
    # NOTE(review): ``grads`` only covers the noise predictor's parameters. This
    # assumes the optimizer in TrainState was built over that same submodule;
    # confirm against offline.modules.base.TrainState.
    train_state.optimizer.update(grads=grads)

    new_graphstate = nnx.split(train_state)[1]
    return new_graphstate, {"loss/diffusion": loss}
