import torch
from enum import Enum

def compute_anomaly_accuracy(pred, labels):
    """Return the fraction of positions where ``pred`` and ``labels`` agree.

    Both inputs are reduced with ``argmax`` over their last dimension, the
    resulting index tensors are compared element-wise, and the boolean
    match mask is cast to float and averaged over every element.

    Args:
        pred: Tensor of per-class scores; the winning class is taken along
            the last dimension.
        labels: Tensor reduced the same way as ``pred``.
            NOTE(review): presumably one-hot or soft targets laid out like
            ``pred`` -- confirm against the callers.

    Returns:
        A zero-dimensional float tensor: the mean of the match mask.
    """
    predicted_class = pred.argmax(-1)
    target_class = labels.argmax(-1)

    # Boolean mask: True wherever the predicted index equals the target.
    is_match = predicted_class == target_class

    # The mean of the 0/1 mask is the share of matching positions.
    return is_match.float().mean()
