import torch
import torch.nn as nn
import torch.nn.functional as F

class CwRLoss(nn.Module):
    """Cross-entropy loss plus a weighted pull toward a fixed "rejection" class.

    The loss computed is::

        CE(output, target) + (1 - c) * CE(output, rej_class)

    Both cross-entropy terms use the same label smoothing ``ls`` and the
    default mean reduction. The second term pushes every sample's prediction
    toward the class index ``rej_class``, scaled by ``1 - c``.
    (NOTE(review): the name presumably stands for "classification with
    rejection" — confirm against the caller/paper this was written for.)

    Args:
        c: Weighting coefficient; the rejection term is scaled by ``1 - c``.
        ls: Label-smoothing amount passed to both cross-entropy terms.
        rej_class: Class index used as the rejection label. Defaults to 7,
            the value that was previously hard-coded, so existing callers
            see identical behavior. Must be a valid class index for the
            logits passed to ``forward`` (``0 <= rej_class < C``).
    """

    def __init__(self, c, ls, rej_class=7):
        super().__init__()
        self.c = c
        self.ls = ls
        self.rej_class = rej_class

    def forward(self, output, target):
        """Return the combined classification + rejection loss.

        Args:
            output: Unnormalized logits in any layout ``F.cross_entropy``
                accepts (e.g. ``(N, C)``).
            target: Class-index targets matching ``output``.

        Returns:
            Scalar loss tensor.
        """
        # Constant label tensor with target's shape, dtype and device, every
        # entry set to the rejection class index.
        label_rej = torch.full_like(target, self.rej_class)
        loss_cls = F.cross_entropy(output, target, label_smoothing=self.ls)
        loss_rej = F.cross_entropy(output, label_rej, label_smoothing=self.ls)
        return loss_cls + (1 - self.c) * loss_rej