import torch
import torch.nn as nn


class CrossEntropyThr(nn.Module):
    """Cross-entropy loss reported as its absolute distance from a threshold.

    The forward pass computes the mean cross-entropy between ``outputs`` and
    ``targets`` and returns ``|cross_entropy - thr|``.

    Keyword Args:
        alpha: Initial threshold value stored as ``thr``. Required.

    Raises:
        KeyError: If ``alpha`` is not passed to the constructor.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Threshold subtracted from the cross-entropy in forward(); it can be
        # replaced later through update().
        self.thr = kwargs['alpha']
        # Use ``reduction``: the legacy ``reduce`` argument is a deprecated
        # boolean flag, and passing a string to it only works because the
        # string is truthy (and it triggers a deprecation warning).
        self.ce = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, outputs, targets):
        """Return ``|mean cross-entropy(outputs, targets) - thr|``.

        Args:
            outputs: Predictions passed as the input of ``nn.CrossEntropyLoss``.
            targets: Targets passed to ``nn.CrossEntropyLoss``.

        Returns:
            The absolute difference between the mean cross-entropy and the
            current threshold.
        """
        loss = self.ce(outputs, targets)
        return torch.abs(loss - self.thr)

    def update(self, thr=0.0):
        """Replace the current threshold with ``thr`` (defaults to ``0.0``)."""
        self.thr = thr


if __name__ == '__main__':
    # Manual smoke run: draw a random 16 x 100 tensor and print its size.
    sample = torch.randn((16, 100))
    print(sample.shape)
