import torch

# Public API of this module: the metric functions re-exported by
# ``from <module> import *``.
__all__ = [
    "MSE_accuracy", "bernoulli_accuracy", "class_accuracy",
    "top1_accuracy", "top5_accuracy",
]


def MSE_accuracy(pred_y, data_y):
    """Mean squared error between sigmoid-squashed predictions and targets.

    ``pred_y`` is passed through ``torch.sigmoid`` before the comparison, so
    callers should hand in raw (pre-sigmoid) scores.  Despite the name, the
    returned value is an error, not an accuracy: lower is better.
    """
    residual = torch.sigmoid(pred_y) - data_y
    return residual.square().mean()


def bernoulli_accuracy(z, t):
    """Mean squared difference between predictions ``z`` and targets ``t``.

    No sigmoid is applied here: ``z`` is expected to already be a probability
    (the sigmoid is applied inside the model's ``nn.Sequential``).  Lower is
    better.
    """
    return torch.square(z - t).mean()


def class_accuracy(pred, labels):
    """Percentage of positions where the arg-max class equals ``labels``.

    Args:
        pred: Class scores with the class axis at dim 1, e.g. ``(N, C)`` or
            ``(N, C, *spatial)``.
        labels: Integer class indices shaped like ``pred`` without dim 1,
            e.g. ``(N,)`` or ``(N, *spatial)``.

    Returns:
        A scalar float tensor in ``[0, 100]``; ``nan`` for an empty batch
        (0 / 0).
    """
    hits = pred.argmax(dim=1) == labels
    # Normalise by every compared position, not just the batch size, so that
    # trailing dims (e.g. per-pixel labels) still yield a true percentage.
    # For 1-D labels this equals labels.shape[0], as before.
    return 100 * torch.sum(hits) / hits.numel()


def _topk_accuracy(pred, labels, k):
    """Percentage of positions whose label is among the ``k`` highest scores.

    Args:
        pred: Class scores with the class axis at dim 1, e.g. ``(N, C)`` or
            ``(N, C, *spatial)``; ``torch.topk`` raises if ``C < k``.
        labels: Integer class indices shaped like ``pred`` without dim 1.
        k: Number of top-scoring classes that count as a hit.

    Returns:
        A scalar float tensor in ``[0, 100]``; ``nan`` for an empty batch.
    """
    _, top_idx = torch.topk(pred, k, dim=1)
    # Give labels a class axis so it broadcasts against the k candidates,
    # then collapse that axis: a position is a hit if any candidate matches.
    hits = torch.eq(labels.unsqueeze(1), top_idx).any(dim=1)
    # Divide by the number of compared positions rather than the batch size
    # so trailing dims (e.g. per-pixel labels) still give a percentage.
    return 100 * torch.sum(hits) / hits.numel()


def top1_accuracy(pred, labels):
    """Top-1 accuracy in percent: the best-scoring class equals the label."""
    return _topk_accuracy(pred, labels, 1)


def top5_accuracy(pred, labels):
    """Top-5 accuracy in percent: the label is among the 5 best classes."""
    return _topk_accuracy(pred, labels, 5)
