import torch
from LorentzMACE.tools.torch_geometric import Batch

from LorentzMACE.tools import TensorDict


class ClassificationLoss(torch.nn.Module):
    """Cross-entropy loss between ``pred['energy']`` and ``ref['signal']``.

    Thin wrapper around a ``torch.nn.CrossEntropyLoss`` constructed with its
    default arguments.
    """

    def __init__(self) -> None:
        super().__init__()
        # The criterion is built once and reused on every forward call.
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
        """Return the cross-entropy of the prediction against the reference.

        Args:
            ref: Reference batch; only its ``'signal'`` entry is read and it
                is passed to the criterion as the target.
                NOTE(review): presumably class labels -- confirm with caller.
            pred: Model output mapping; only its ``'energy'`` entry is read
                and it is passed to the criterion as the input.
                NOTE(review): presumably per-class scores -- confirm.

        Returns:
            Whatever ``torch.nn.CrossEntropyLoss`` returns for that pair.
        """
        scores = pred['energy']
        targets = ref['signal']
        return self.loss(scores, targets)

    def __repr__(self) -> str:
        """Return just the class name, without any argument list."""
        label = f'{self.__class__.__name__}'
        return label
