from torch import Tensor
from torch import nn

from constants import DEVICE


def cross_entropy_loss(y: Tensor, target: Tensor) -> Tensor:
    """Compute the mean cross-entropy loss between logits and labels.

    :param y: the raw (unnormalized) logits of the DNN, shape (B, N)
    :param target: the real labels, either as a one-hot / class-probability
        tensor of shape (B, N) (integer or floating dtype) or as integer
        class indices of shape (B,)
    :return: a scalar tensor holding the loss averaged over the batch
    """
    # An integer one-hot target (e.g. from F.one_hot) would otherwise be
    # interpreted as class indices and rejected; class-probability targets
    # must be floating point, so match the logits' dtype.
    if target.shape == y.shape and not target.is_floating_point():
        target = target.to(y.dtype)
    # The functional form avoids constructing (and moving to a device) a
    # parameter-free nn.CrossEntropyLoss module on every call.
    return nn.functional.cross_entropy(y, target)
