import torch
import torch.nn.functional as F

def robust_cross_entropy(logits, target, P, reduction='none', eps=1e-8):
    """Cross-entropy computed on class probabilities mixed through ``P``.

    The softmax of ``logits`` (over the last dimension) is clamped to
    ``[eps, 1 - eps]`` and then mixed with ``torch.inner(probs, P)``, i.e.
    ``mixed[..., j] = sum_i probs[..., i] * P[j, i]``.  The loss is the
    negative log of the mixed probability at the target class.

    NOTE(review): ``P`` is presumably a label-noise transition matrix.
    ``torch.inner`` contracts ``P``'s *last* dimension, so this equals
    ``probs @ P.T``.  Confirm that this matches the row/column convention
    used where ``P`` is built.

    Args:
        logits: Unnormalised scores with classes along the last dimension.
        target: Integer class indices; their one-hot encoding must broadcast
            against ``logits``.
        P: Mixing matrix whose last dimension equals the number of classes.
        reduction: ``'none'`` (per-sample loss), ``'mean'`` or ``'sum'``.
        eps: Clamp bound applied to the softmax probabilities before mixing.

    Returns:
        The per-sample loss tensor for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    num_classes = logits.shape[-1]
    probs = torch.clamp(logits.softmax(-1), min=eps, max=1 - eps)
    log_mixed = torch.log(torch.inner(probs, P))

    # Select the target-class log-probability with a mask rather than by
    # multiplying with a one-hot tensor. If a mixed probability is exactly 0
    # at a non-target class, its log is -inf, and 0 * (-inf) would be NaN.
    mask = F.one_hot(target.long(), num_classes).bool()
    picked = torch.where(mask, log_mixed, log_mixed.new_zeros(()))
    per_sample = -picked.sum(dim=-1)

    if reduction == 'none':
        return per_sample
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    raise ValueError(
        f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
    )

class MultiHeadNoisyEDMLoss:
    """EDM-weighted denoising loss over several output heads, scored as a
    classification problem with ``robust_cross_entropy``.

    For each sample, the weighted denoising error of every head is turned
    into a logit (lower error -> higher logit).  The argmax of ``labels`` is
    treated as the target head.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, **kwargs):
        # Noise levels are sampled as log(sigma) ~ N(P_mean, P_std ** 2).
        self.P_mean = P_mean
        self.P_std = P_std
        # Used in the loss weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
        self.sigma_data = sigma_data
        # Any extra keyword arguments are accepted and ignored.

    def __call__(self, net, images, labels, P):
        """Return ``(classification_loss, regulariser)`` as scalar means.

        Args:
            net: Denoiser called as ``net(noisy_images, sigma)``.  Its channel
                dimension is split into ``labels.size(1)`` equal groups.
                Each group is compared against ``images``, so the output
                presumably has ``labels.size(1) * C`` channels.
            images: Clean batch of shape ``(B, C, H, W)``.
            labels: ``(B, K)`` tensor; ``argmax`` over dim 1 selects the
                target head (presumably one-hot — confirm with the caller).
            P: Matrix forwarded unchanged to ``robust_cross_entropy``.

        Returns:
            Two scalar tensors:

            * the batch mean of the per-sample robust cross-entropy;
            * the batch mean of ``-log(sum_k exp(logits_k) + 1e-8)``.
        """
        batch = images.shape[0]
        num_heads = labels.size(1)

        rnd_normal = torch.randn([batch, 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        n = torch.randn_like(images) * sigma

        D_yn = net(images + n, sigma)
        # Split output channels into per-head groups and broadcast the clean
        # images against them. This avoids materialising a num_heads-fold
        # copy of `images` with .repeat(). reshape (not view) also tolerates
        # a non-contiguous network output.
        D_yn = D_yn.reshape(batch, num_heads, -1, D_yn.shape[2], D_yn.shape[3])
        # weight is (B, 1, 1, 1); add a head axis so it broadcasts per sample.
        sq_err = weight.unsqueeze(1) * (D_yn - images.unsqueeze(1)) ** 2

        logits = -sq_err.mean(dim=(2, 3, 4))  # (B, num_heads)
        ce = robust_cross_entropy(logits, torch.argmax(labels, dim=1), P, reduction="none")

        # -log(sum_k exp(logits_k) + 1e-8), evaluated in log space so that
        # large errors do not underflow exp() to zero.
        lse = torch.logsumexp(logits, dim=1)
        log_eps = torch.log(lse.new_tensor(1e-8))
        loss_reg = -torch.logaddexp(lse, log_eps)

        return ce.mean(), loss_reg.mean()