import torch

def cross_entropy_loss(inputs, target, eps=1e-8):
    """Return the element-wise cross-entropy term ``-target * log(inputs + eps)``.

    No reduction is applied: the result has the broadcast shape of
    ``inputs`` and ``target``, and callers sum or average as needed.

    Args:
        inputs: Tensor passed to the logarithm (callers in this file pass
            probabilities).
        target: Tensor of per-element weights multiplied with the negative
            log term.
        eps: Constant added to ``inputs`` before the logarithm so that a
            zero input yields a finite value instead of ``-inf``.

    Returns:
        The un-reduced loss tensor.
    """
    negated_target = -target
    log_inputs = torch.log(inputs + eps)
    return negated_target * log_inputs

class MultiHeadCLSEDMLoss:
    """Classification loss built from per-head weighted denoising errors.

    A noise level ``sigma`` is drawn per sample from a log-normal
    distribution (``exp(P_mean + P_std * N(0, 1))``), noise of that scale is
    added to the images, and ``net`` is asked to denoise them. The network
    output is split into ``labels.size(1)`` heads along dimension 1; each
    head's mean weighted squared error against the clean image is negated
    and used as that head's logit. The softmax over heads is compared with
    ``labels`` via ``cross_entropy_loss``.

    Args:
        P_mean: Mean of the normal distribution over ``log(sigma)``.
        P_std: Standard deviation of the normal distribution over
            ``log(sigma)``.
        sigma_data: Constant used in the per-sample loss weight
            ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, **kwargs):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels):
        """Compute the classification loss and the regularisation term.

        Args:
            net: Callable invoked as ``net(noisy_images, sigma)``. Its output
                must be compatible with ``images`` repeated
                ``labels.size(1)`` times along dimension 1.
            images: 4-D tensor of clean images; dimension 0 is the batch.
            labels: 2-D tensor whose second dimension is the number of heads.
                Presumably one-hot or soft class targets -- confirm with the
                caller.

        Returns:
            A tuple ``(loss, loss_reg)`` of scalar tensors: the batch mean of
            the cross-entropy between the softmax over heads and ``labels``,
            and the batch mean of ``-log(sum_k exp(logit_k) + 1e-8)``.
        """
        num_heads = labels.size(1)

        # One noise level per sample, broadcast over the remaining dimensions.
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        n = torch.randn_like(images) * sigma

        D_yn = net(images + n, sigma)
        # Every head is compared against the same clean image.
        loss = weight * ((D_yn - images.repeat(1, num_heads, 1, 1)) ** 2)
        # Split dimension 1 into (head, per-head channels).
        loss = loss.view(loss.shape[0], num_heads, -1, loss.shape[2], loss.shape[3])

        # Lower reconstruction error -> larger logit for that head.
        logits = -loss.mean(dim=(2, 3, 4))
        # torch.softmax subtracts the row maximum internally. A manual
        # exp(logits) / exp(logits).sum() gives 0 / 0 = NaN whenever every
        # head's logit is negative enough for exp() to underflow, which
        # happens when the weight is large at small sigma.
        confidence = torch.softmax(logits, dim=1)
        loss = cross_entropy_loss(confidence, labels).sum(dim=1)
        # logits <= 0, so exp() cannot overflow; the epsilon keeps the log
        # finite when every term underflows to zero.
        loss_reg = -torch.log(torch.exp(logits).sum(dim=1) + 1e-8)
        return loss.mean(), loss_reg.mean()