import jax.numpy as jnp


def accuracy(pred, target, mask=None):
    """Return the (optionally masked) fraction of correct argmax predictions.

    Args:
        pred: Score array; the predicted class is the argmax over its last
            axis (presumably logits or probabilities -- not checked here).
        target: Integer class labels, compared elementwise with the argmax
            of ``pred``.
        mask: Optional array multiplied into the per-position correctness.
            Its entries act as weights: zero entries are excluded, and the
            result is normalised by ``jnp.sum(mask)``.

    Returns:
        Scalar accuracy: the plain mean of correctness when ``mask`` is None,
        otherwise the mask-weighted mean.
    """
    predicted = jnp.argmax(pred, axis=-1)
    hits = predicted == target

    # Guard clause: without a mask every position counts equally.
    if mask is None:
        return jnp.mean(hits)

    weighted_hits = jnp.sum(mask * hits)
    return weighted_hits / jnp.sum(mask)


def accuracy_metrics(pred, target, mask=None):
    """Compute a small dictionary of accuracy metrics for one sequence.

    Assumes a single, unbatched sequence: the correctness vector is indexed
    by position (``correct[0]``, ``correct[last]``). TODO confirm callers
    never pass a batched input.

    Args:
        pred: Score array; the predicted class is the argmax over its last
            axis (presumably logits or probabilities -- not checked here).
        target: Integer class labels, compared elementwise with the argmax
            of ``pred``.
        mask: Optional 1-D array, one entry per position, multiplied into the
            correctness. Its entries act as weights. Non-zero entries mark
            positions that count, and the mean is normalised by
            ``jnp.sum(mask)``. When None, every position counts.

    Returns:
        dict with keys:
            "acc": the (mask-weighted) mean correctness.
            "acc_all": whether ``acc`` is exactly 1.0.
            "acc_first": correctness at position 0.
            "acc_last": correctness at the last position with a non-zero mask
                entry (0 if the mask is all zero), or at the final position
                when no mask is given.
    """
    correct = jnp.argmax(pred, axis=-1) == target

    if mask is not None:
        mask = jnp.asarray(mask)
        acc = jnp.sum(mask * correct) / jnp.sum(mask)
        # Highest index whose mask entry is non-zero. jnp.where keeps the
        # result integer-typed even for float / weighted masks, so it is
        # always a valid index into ``correct``.
        positions = jnp.arange(mask.shape[0])
        last_unmasked = jnp.max(jnp.where(mask != 0, positions, 0))
    else:
        acc = jnp.mean(correct)
        # No mask: every position is valid, so "last" is the final element.
        last_unmasked = -1

    return {
        "acc": acc,
        "acc_all": acc == 1.0,
        "acc_first": correct[0],
        "acc_last": correct[last_unmasked],
    }
