import torch
import torch.nn as nn


class EECELoss(nn.Module):
    """Cross-entropy loss evaluated on the average of two output tensors.

    The two tensors passed to ``forward`` are summed element-wise, halved,
    and the result is scored against the targets with a default-configured
    ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Default settings are used; no weights or reduction are overridden.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, outputs, ys, outputs_victim):
        """Return the cross-entropy of the averaged outputs against ``ys``.

        Args:
            outputs: First output tensor.
                NOTE(review): presumably raw logits of shape
                (batch, num_classes) -- confirm against the caller.
            ys: Targets handed unchanged to the criterion.
            outputs_victim: Second output tensor, added element-wise to
                ``outputs`` (so it must be broadcast-compatible with it).
                NOTE(review): presumably the victim model's logits -- confirm.

        Returns:
            Whatever ``self.criterion`` returns for the averaged tensor.
        """
        # The averaging happens on the raw inputs, before the criterion
        # sees them.
        combined = outputs + outputs_victim
        averaged = combined / 2
        return self.criterion(averaged, ys)
