import torch

class MultiHeadEDMLoss:
    """Weighted denoising loss for a network that emits one output group per label.

    On each call a per-sample noise level ``sigma`` is drawn from a log-normal
    distribution (``exp(N(P_mean, P_std**2))``), Gaussian noise of that scale
    is added to ``images``, and the network is asked to recover the clean
    images.  The network output is split along dim 1 into ``labels.size(1)``
    groups ("heads"); for every sample only the head at ``argmax(labels)`` is
    compared against the clean image.

    NOTE(review): the name and the weighting formula suggest this follows the
    EDM training objective (Karras et al., 2022) -- confirm against the
    training script that uses this class.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, **kwargs):
        """Store the noise-distribution parameters.

        Args:
            P_mean: Mean of ``log(sigma)``.
            P_std: Standard deviation of ``log(sigma)``.
            sigma_data: Constant used in the per-sample loss weight.
            **kwargs: Accepted and ignored.
        """
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels):
        """Return the element-wise weighted squared error (not reduced).

        Args:
            net: Callable invoked as ``net(noisy_images, sigma)``.  Its output
                must have at least 4 dims, with dim 1 divisible by
                ``labels.size(1)``.
            images: Clean images; indexed as a batch along dim 0.
                Assumed to be 4-D ``(B, C, H, W)`` so that the ``(B, 1, 1, 1)``
                sigma broadcasts per sample -- TODO confirm.
            labels: Tensor whose dim 1 has one entry per head; the head used
                for each sample is ``argmax`` over that dim (presumably
                one-hot class labels -- verify against caller).

        Returns:
            ``weight * (selected_head_output - images) ** 2`` with no
            reduction applied; the caller is responsible for summing/averaging.
        """
        # One log-normal noise level per sample, shaped to broadcast over the
        # remaining image dims.
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        noise = torch.randn_like(images) * sigma

        D_yn = net(images + noise, sigma)

        # Split dim 1 into (num_heads, channels_per_head).
        num_heads = labels.size(1)
        D_yn = D_yn.view(D_yn.shape[0], num_heads, -1, D_yn.shape[2], D_yn.shape[3])

        # Build both index tensors on the output's device so the gather does
        # not trigger an implicit host-to-device copy on every call.
        # torch.argmax already returns int64, so no extra cast is needed.
        batch_idx = torch.arange(D_yn.shape[0], device=D_yn.device)
        head_idx = torch.argmax(labels, dim=1).to(D_yn.device)
        D_yn_cond = D_yn[batch_idx, head_idx]

        return weight * ((D_yn_cond - images) ** 2)