import torch


# Public metric helpers exported by ``from <this module> import *``.
__all__ = ["MSE_accuracy", "bernoulli_accuracy", "class_accuracy"]


def MSE_accuracy(z, t):
    """Return the mean squared error between predictions ``z`` and targets ``t``.

    Despite the name this is an error metric: lower is better. ``z`` and ``t``
    are combined with ``z - t``, so they follow torch broadcasting rules.

    Returns:
        A 0-dim tensor holding the mean of the element-wise squared difference.
    """
    return torch.square(z - t).mean()


def bernoulli_accuracy(z, t):
    """Return the mean squared gap between predicted probabilities and targets.

    No sigmoid is applied here: per the original author's note, the sigmoid is
    already applied in the model's ``nn.Sequential()``, so ``z`` is expected to
    hold probabilities already. Lower is better.

    Returns:
        A 0-dim tensor holding the mean of ``(z - t) ** 2``.
    """
    residual = z - t
    return torch.square(residual).mean()


def class_accuracy(z, t):
    """Return the fraction of samples whose predicted class matches the target.

    Args:
        z: Class scores (logits or probabilities) with the class dimension
            last, e.g. shape ``(batch, num_classes)`` or ``(num_classes,)``.
        t: Targets, either integer class indices shaped like ``z.shape[:-1]``
            (a Python int is accepted for a single sample), or one-hot / soft
            target rows with exactly the same shape as ``z``.

    Returns:
        A 0-dim float tensor in ``[0, 1]``. An empty batch yields ``nan``
        (mean of an empty tensor).
    """
    # Reduce over the class dimension only. A bare argmax(z) would index
    # into the flattened tensor and mix up samples.
    pred = torch.argmax(z, dim=-1)
    t = torch.as_tensor(t, device=pred.device)
    if z.dim() > 0 and t.shape == z.shape:
        # One-hot / soft targets: compare against their most likely class.
        t = torch.argmax(t, dim=-1)
    # Compare element-wise, then average as floats. Calling float() on a
    # multi-element tensor would raise, and torch.mean needs a Tensor input.
    return (pred == t).float().mean()
