import torch
import torch.nn.functional as F

def cross_entropy_loss(inputs, target, eps=1e-8):
    """Return the element-wise cross-entropy terms ``-target * log(inputs + eps)``.

    No reduction is performed: the result has the broadcast shape of
    ``inputs`` and ``target``. ``eps`` is added inside the logarithm so that
    zero entries in ``inputs`` give a large finite penalty instead of ``inf``.
    """
    log_probs = torch.log(inputs + eps)
    return -target * log_probs

def propotion_loss(inputs, target, eps=1e-8):
    """Return the cross-entropy between ``target`` and ``inputs`` over the last axis.

    The element-wise terms from ``cross_entropy_loss`` are summed along the
    last dimension, which is dropped; any leading dimensions are kept
    (no batch reduction). Presumably both arguments are class-proportion
    vectors along that last axis — confirm against callers.
    """
    per_class = cross_entropy_loss(inputs, target, eps=eps)
    return per_class.sum(dim=-1)

class MultiHeadLLPEDMLoss:
    """EDM-weighted multi-head denoising loss driven by bag-level class proportions.

    A noise level ``sigma = exp(N(P_mean, P_std))`` is drawn per bag and shared
    by all instances in it. The network denoises each noisy instance with one
    reconstruction per class head. The negative mean weighted error of each
    head becomes a logit, and a softmax over heads gives per-instance class
    confidences. The confidences are averaged over the bag and compared with
    ``labels_prop`` via ``propotion_loss``.

    Args:
        P_mean: Mean of the normal distribution from which ``log(sigma)`` is drawn.
        P_std: Standard deviation of that normal distribution.
        sigma_data: Data scale used in the weight
            ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, **kwargs):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels_prop):
        """Compute the proportion loss and the head-regularization loss.

        Args:
            net: Callable ``net(x, sigma)``. It receives noisy images of shape
                ``(b*i, ch, w, h)`` and per-image noise levels of shape
                ``(b*i, 1, 1, 1)``. It must return ``(b*i, K*ch, w, h)``,
                i.e. ``K`` stacked reconstructions, where
                ``K = labels_prop.size(1)``.
            images: Tensor of shape ``(b, i, ch, w, h)``: ``b`` bags of ``i``
                instances each.
            labels_prop: Per-bag target proportions with ``K`` entries along
                dim 1; expected shape ``(b, K)``.

        Returns:
            Tuple ``(loss_llp, loss_reg)`` of scalar tensors:
            - ``loss_llp``: mean over bags of the proportion cross-entropy.
            - ``loss_reg``: mean over instances of
              ``-log(sum_k exp(logit_k) + 1e-8)``.
        """
        # One noise level per bag (note the singleton instance axis).
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        n = torch.randn_like(images) * sigma

        # Flatten bags into a plain image batch: (b, i, ch, w, h) -> (b*i, ch, w, h).
        (b, i, ch, w, h) = images.size()
        images = images.reshape(-1, ch, w, h)
        n = n.reshape(-1, ch, w, h)

        # Give every instance its bag's sigma / weight: (b, 1, ...) -> (b*i, 1, 1, 1).
        sigma = sigma.repeat(1, i, 1, 1, 1).reshape(-1, 1, 1, 1)
        weight = weight.repeat(1, i, 1, 1, 1).reshape(-1, 1, 1, 1)

        num_heads = labels_prop.size(1)
        D_yn = net(images + n, sigma)
        # Tile the clean image once per head so each head's output is compared to it.
        loss = weight * ((D_yn - images.repeat(1, num_heads, 1, 1)) ** 2)
        loss = loss.view(loss.shape[0], num_heads, -1, loss.shape[2], loss.shape[3])

        # Per-instance, per-head logit = negative mean weighted error: (b*i, K).
        logits = -loss.mean(dim=(2, 3, 4))
        # torch.softmax subtracts the row max before exponentiating. Normalizing
        # raw exp(logits) produced 0/0 = NaN whenever every head's logit was
        # below roughly -104, where float32 exp underflows to exactly 0.
        confidence = torch.softmax(logits, dim=1)
        pred_prop = confidence.reshape(b, i, -1).mean(dim=1)

        loss_llp = propotion_loss(pred_prop, labels_prop)
        # The +1e-8 keeps the log finite if every exp(logit) underflows.
        loss_reg = - torch.log(torch.exp(logits).sum(dim=1) + 1e-8)

        return loss_llp.mean(), loss_reg.mean()