import torch
import torch.nn as nn


class ECELoss(nn.Module):
    """Cross-entropy loss computed on the mean of two sets of logits.

    ``forward`` averages ``outputs`` with a detached copy of
    ``outputs_victim`` and scores the result with ``CrossEntropyLoss``
    against ``ys``. Because the victim logits are detached, gradients from
    this loss flow only into ``outputs``.

    NOTE(review): despite the name, this does not compute an Expected
    Calibration Error; it is an ensemble cross-entropy.
    """

    def __init__(self):
        super().__init__()
        # Default CrossEntropyLoss (mean reduction over the batch).
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, outputs, ys, outputs_victim):
        """Return the cross-entropy of the averaged logits.

        Args:
            outputs: logits that should receive gradients.
            ys: target labels passed straight to ``CrossEntropyLoss``.
            outputs_victim: logits to average in; detached, so they receive
                no gradient. Presumably the same shape as ``outputs`` —
                TODO confirm with callers.

        Returns:
            The scalar loss tensor produced by ``self.criterion``.
        """
        # Cut the victim branch out of the autograd graph before mixing.
        frozen_victim = outputs_victim.detach()
        averaged_logits = (outputs + frozen_victim).div(2)
        return self.criterion(averaged_logits, ys)
