import torch
import torch.nn as nn


class RCELoss(nn.Module):
    """Cross-entropy minus its average over every possible target class.

    For logits ``outputs`` and integer targets ``ys`` the loss is::

        CE(outputs, ys) - (1 / C) * sum_{i=0}^{C-1} CE(outputs, i)

    where ``C`` is the size of the class dimension and ``CE`` is
    :class:`torch.nn.CrossEntropyLoss` with its default ``mean`` reduction.
    Because every constant-target CE averages ``-log p_i`` over the same
    positions, the second term equals
    ``-log_softmax(outputs, dim=class_dim).mean()``. It is therefore computed
    with one ``log_softmax`` rather than ``C`` separate cross-entropy calls.
    """

    def __init__(self):
        super(RCELoss, self).__init__()
        # Public attribute kept for callers that inspect/replace it; it
        # produces the first (true-label) cross-entropy term.
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, outputs, ys):
        """Compute the loss.

        Args:
            outputs: unnormalized logits. The class dimension is dim 1 for
                batched input ``(N, C, ...)`` and dim 0 for a single ``(C,)``
                sample, matching ``CrossEntropyLoss``.
            ys: integer class indices accepted by ``self.criterion``.

        Returns:
            0-dim tensor: ``self.criterion(outputs, ys)`` minus the mean
            cross-entropy against each constant target ``0 .. C-1``.
        """
        class_dim = 1 if outputs.dim() > 1 else 0
        loss = self.criterion(outputs, ys)
        # mean over i of CE(outputs, i) == -(mean of log-probabilities over
        # every element), including the class dimension. This avoids C
        # redundant log_softmax passes and per-class target allocations.
        uniform_ce = -torch.log_softmax(outputs, dim=class_dim).mean()
        # Out-of-place subtraction: don't mutate an autograd output in place.
        return loss - uniform_ce
