from flax import nnx
from jax import Array, jit, numpy as jnp

from offline.modules.actor.utils import gaussian_log_likelihood
from offline.modules.base import TrainState
from offline.bc.modules import BCPolicy


def behavior_cloning_loss_fn(
    policy: BCPolicy,
    actions: Array,
    observations: Array,
):
    """Return the behavior-cloning loss: the mean negative log-likelihood.

    The policy's actor is evaluated on ``observations`` to obtain the
    distribution parameters, and ``actions`` are then scored under them
    with ``gaussian_log_likelihood``.

    Args:
        policy: Policy whose ``actor`` maps observations to a
            ``(means, stds)`` pair.
        actions: Actions to imitate, passed as the ``samples`` being scored.
        observations: Inputs fed to the actor (presumably aligned with
            ``actions`` along a leading batch axis -- confirm with caller).

    Returns:
        The negated mean of the log-likelihoods, taken over all elements.
    """
    action_means, action_stds = policy.actor(observations)
    mean_log_prob = jnp.mean(
        gaussian_log_likelihood(
            means=action_means, samples=actions, stds=action_stds
        )
    )
    # Maximizing the likelihood is expressed as minimizing its negation.
    return -mean_log_prob


@jit
def train_step(
    actions: Array,
    graphdef: nnx.GraphDef[TrainState[BCPolicy]],
    graphstate: nnx.GraphState | nnx.VariableState,
    observations: Array,
):
    """Run one jit-compiled behavior-cloning update.

    The train state is rebuilt from its ``(graphdef, graphstate)`` pair,
    updated once, and split again so only the state is handed back.

    Args:
        actions: Target actions forwarded to ``behavior_cloning_loss_fn``.
        graphdef: Graph definition of the train state.
        graphstate: State matching ``graphdef``.
        observations: Observations forwarded to ``behavior_cloning_loss_fn``.

    Returns:
        A ``(graphstate, metrics)`` tuple: the state obtained by splitting
        the train state after the optimizer update, and a dict holding the
        loss under the ``"loss"`` key.
    """
    state = nnx.merge(graphdef, graphstate)

    # Differentiate the loss with respect to its first argument, the model.
    loss, grads = nnx.value_and_grad(behavior_cloning_loss_fn)(
        state.model, actions, observations
    )
    state.optimizer.update(grads)

    # Only the state half of the split is returned; the graphdef is dropped.
    updated_graphstate = nnx.split(state)[1]
    metrics = {"loss": loss}
    return updated_graphstate, metrics
