import jax.numpy as jnp

from egxc.utils.typing import (
    Float1,
    FloatAx3,
)


def force_loss(target: FloatAx3, prediction: FloatAx3) -> Float1:
    """Squared-error loss between reference and predicted nuclear forces.

    The squared residual of each vector component is summed along the last
    axis. This gives one squared error norm per force vector. Those norms are
    then averaged over all remaining axes.

    Note: this is not the element-wise MSE. With a trailing axis of length 3
    (as the ``FloatAx3`` annotation suggests), it equals three times
    ``jnp.mean(jnp.square(target - prediction))``.

    Args:
        target: Reference forces; the last axis holds the vector components.
        prediction: Predicted forces, expected to match ``target``'s shape
            (the subtraction below follows normal broadcasting rules).

    Returns:
        Scalar mean of the per-vector squared error norms.
    """
    residual = target - prediction
    # ||F_ref - F_pred||^2 for every force vector (reduce over components).
    sq_norm_per_vector = jnp.sum(jnp.square(residual), axis=-1)
    return jnp.mean(sq_norm_per_vector)
