import math
from jax import numpy as jnp
from latte_trans.evals.base import Evaluator
from latte_trans.evals.losses import cross_entropy_loss


def acc_class(loss_fn, output, labels):
    """Compute loss and accuracy metrics for a classification output.

    Args:
        loss_fn: Callable invoked as ``loss_fn(logits=..., target=...)``;
            whatever it returns is reported as the loss.
        output: Mapping that holds the model predictions under the
            ``"logits"`` key.
        labels: Class labels compared element-wise against the arg-max of
            the logits along the last axis.

    Returns:
        dict with ``"loss"`` (the ``loss_fn`` result) and ``"accuracy"``
        (the fraction of positions where the predicted class equals the label).
    """
    logits = output["logits"]
    loss_value = loss_fn(logits=logits, target=labels)
    # Boolean hit mask; its mean is the accuracy.
    hits = jnp.argmax(logits, axis=-1) == labels
    return {"loss": loss_value, "accuracy": jnp.mean(hits)}


class ClassificEvaluator(Evaluator):
    """Evaluator for classification tasks that reports loss and accuracy.

    Metrics come from ``acc_class`` using ``cross_entropy_loss`` as the
    loss function.
    """

    def __init__(self, val_data, data_collator, config) -> None:
        super().__init__(val_data, data_collator, config)
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by the base Evaluator or callers — confirm.
        self._batchnorm = config.batchnorm

    def compute_metrics(self, output, labels):
        """Return the ``{"loss", "accuracy"}`` metrics for ``output`` vs ``labels``.

        ``output`` must provide the predictions under the ``"logits"`` key.
        """
        metrics = acc_class(
            loss_fn=cross_entropy_loss, output=output, labels=labels
        )
        return metrics
