import torch


def brier_loss(softmaxes, labels):
    """Return the multi-class Brier score averaged over the batch.

    For each sample, the squared differences between the predicted
    probabilities and the one-hot encoding of the true label are summed
    over classes. The per-sample sums are then averaged.

    Args:
        softmaxes: Predicted class probabilities. The two-index assignment
            and the ``dim=1`` reduction below imply shape ``(N, C)``.
        labels: Integer class indices, one per row of ``softmaxes``
            (length ``N``).

    Returns:
        A 0-dim tensor holding the mean Brier score.
    """
    batch_size = labels.size(0)
    row_idx = torch.arange(batch_size)

    # Dense one-hot target with the same shape, dtype and device as the
    # predictions. Exactly one entry per row is set to 1.
    target = torch.zeros_like(softmaxes)
    target[row_idx, labels] = 1.0

    per_sample = (target - softmaxes).pow(2).sum(dim=1)
    return per_sample.mean()
