from transformers import EvalPrediction
import torch
import torch.nn as nn
import numpy as np

def classification_accuracy(res: EvalPrediction):
    """Compute top-1 classification accuracy from evaluation outputs.

    Args:
        res: Evaluation outputs whose first two positional elements are the
            logits and the labels, in that order. Any further elements
            (e.g. inputs or losses carried alongside the predictions) are
            ignored. Presumably logits are ``(..., num_classes)`` and labels
            have the shape of the argmax result -- confirm against caller.

    Returns:
        A dict ``{"accuracy": value}`` where ``value`` is the fraction of
        positions whose argmax over the last logits axis equals the label.
    """
    # Index instead of tuple-unpacking: an EvalPrediction may iterate to more
    # than two elements, in which case ``logits, labels = res`` raises
    # ValueError. Indexing works for both that case and plain 2-tuples.
    logits, labels = res[0], res[1]
    pred = np.argmax(logits, axis=-1)
    return {"accuracy": (pred == labels).mean()}

def seq_cross_entropy(res: EvalPrediction):
    """Compute the mean cross-entropy of per-position logits against labels.

    Args:
        res: Evaluation outputs whose first two positional elements are the
            logits and the labels (numpy arrays), in that order. Any further
            elements are ignored. Logits must be 3-D; presumably
            ``(batch, seq_len, num_classes)`` with integer class-index
            labels of shape ``(batch, seq_len)`` -- confirm against caller.

    Returns:
        A dict ``{"cross_entropy": value}`` with the loss as a Python float.
        Positions whose label equals ``nn.CrossEntropyLoss``'s default
        ``ignore_index`` (-100) do not contribute to the mean.
    """
    # Index instead of tuple-unpacking so that evaluation outputs carrying
    # extra elements beyond (logits, labels) do not raise ValueError.
    logits, labels = res[0], res[1]

    # CrossEntropyLoss expects the class dimension at index 1, so swap the
    # last two axes of the logits.
    _logits = torch.from_numpy(logits).float().permute(0, 2, 1)
    # Class-index targets must be int64; numpy label arrays are often int32,
    # which CrossEntropyLoss would otherwise reject.
    _labels = torch.from_numpy(labels).long()

    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(_logits, _labels)
    return {"cross_entropy": loss.item()}