import torch
import torch.nn.functional as F

def get_lord_error_fn(fn, params, state, ord):
    """Build a scorer that returns the L-``ord`` norm of each example's softmax error.

    For a batch ``X`` with targets ``Y``, the scorer computes
    ``softmax(fn(params, state, X)) - Y`` and takes its vector norm of order
    ``ord`` along the last axis, giving one score per example.

    Args:
        fn: Model apply function called as ``fn(params, state, X)``; its output
            is treated as logits and soft-maxed over the last axis.
        params: Model parameters forwarded to ``fn``.
        state: Extra model state forwarded to ``fn``.
        ord: Norm order passed to ``jax.numpy.linalg.norm``. It must be a
            static Python value because it is closed over by the jitted
            function.

    Returns:
        A callable ``(X, Y) -> numpy.ndarray`` that returns the per-example
        scores as a host NumPy array.
    """
    # The body is written against JAX, but this module imports only torch, so
    # these names would otherwise be undefined (NameError on first use).
    # Import them locally to keep the function self-contained.
    import jax.numpy as jnp
    import numpy as np
    from jax import jit, nn

    @jit
    def lord_error(X, Y):
        # params/state/ord are captured from the enclosing scope rather than
        # passed as traced arguments to the jitted function.
        # NOTE(review): Y presumably has the same shape as the model output
        # (e.g. one-hot labels); confirm against callers.
        errors = nn.softmax(fn(params, state, X)) - Y
        return jnp.linalg.norm(errors, ord=ord, axis=-1)

    def np_lord_error(X, Y):
        """Evaluate the jitted scorer and copy the result into a NumPy array."""
        return np.array(lord_error(X, Y))

    return np_lord_error



