from flax import nnx
from jax import Array, jit
from jax.random import fold_in, normal, randint

from offline.diffusion.modules.unconditional import (
    unconditional_diffusion_loss_fn,
)
from offline.diffusion.repaint.modules import InpaintPolicy
from offline.modules.base import TrainState


@jit
def train_step(
    batch: Array,
    diffusion_steps: int,
    graphdef: nnx.GraphDef[TrainState[InpaintPolicy]],
    graphstate: nnx.GraphState | nnx.VariableState,
    step: int,
    train_noise_key: Array,
    train_time_key: Array,
):
    """Perform one jitted optimisation step on the diffusion noise predictor.

    The train state is rebuilt from ``graphdef``/``graphstate``. Gaussian noise
    and integer timesteps are drawn from keys derived by folding ``step`` into
    the given base keys. The batch is corrupted with ``diffusion.add_noise``,
    and the optimizer is stepped on the gradient of
    ``unconditional_diffusion_loss_fn`` with respect to the noise predictor.

    Args:
        batch: Clean samples. The noise tensor has exactly ``batch.shape``.
            Timesteps have shape ``batch.shape[:-1] + (1,)``.
        diffusion_steps: Exclusive upper bound for sampled timesteps (values
            lie in ``[0, diffusion_steps)``). Traced under ``jit``, not static.
        graphdef: Static graph definition of the ``TrainState``.
        graphstate: State pytree merged with ``graphdef`` to rebuild the
            train state.
        step: Step counter folded into both RNG keys. Traced under ``jit``.
        train_noise_key: Base PRNG key for the Gaussian noise.
        train_time_key: Base PRNG key for the timesteps.

    Returns:
        A ``(graphstate, metrics)`` pair. ``graphstate`` is the post-update
        state from ``nnx.split``. ``metrics`` maps ``"loss/diffusion"`` to the
        loss value.

    NOTE(review): gradients are taken w.r.t. ``model.diffusion.noise_predictor``
    only. Confirm that ``optimizer`` was built over a matching parameter tree.
    """
    train_state = nnx.merge(graphdef, graphstate)
    diffusion = train_state.model.diffusion

    # Fold the step into each base key so every step draws fresh randomness
    # while the run stays reproducible from the same base keys.
    step_noise_key = fold_in(train_noise_key, step)
    step_time_key = fold_in(train_time_key, step)

    # One timestep per leading index, with a trailing singleton axis.
    timestep_shape = (*batch.shape[:-1], 1)
    eps = normal(step_noise_key, batch.shape)
    t = randint(step_time_key, timestep_shape, minval=0, maxval=diffusion_steps)

    corrupted = diffusion.add_noise(noise=eps, samples=batch, timesteps=t)

    loss, grads = nnx.value_and_grad(unconditional_diffusion_loss_fn)(
        diffusion.noise_predictor,
        noise=eps,
        noisy_samples=corrupted,
        timesteps=t,
    )
    train_state.optimizer.update(grads=grads)

    # Only the state half is returned; the caller already holds the graphdef.
    new_graphstate = nnx.split(train_state)[1]
    return new_graphstate, {"loss/diffusion": loss}
