import torch


def cross_entropy(*_):
    """Build a loss criterion that averages per-sample cross-entropy.

    Any positional arguments given to this factory are accepted and ignored,
    so it can be called with the same signature as other criterion factories.

    Returns:
        A callable ``criterion(logits, y, *extra)`` that returns a scalar
        tensor: the unweighted mean of the element-wise cross-entropy
        between ``logits`` and targets ``y``. Extra positional arguments
        are ignored.
    """
    functional = torch.nn.functional

    def criterion(logits, y, *_):
        # Compute the unreduced loss, then average it ourselves.
        # NOTE(review): with reduction='none', targets equal to the default
        # ignore_index contribute a 0 loss but are still counted in the
        # average below, unlike reduction='mean' -- confirm this is intended.
        per_element = functional.cross_entropy(logits, y, reduction='none')
        return torch.mean(per_element)

    return criterion
