"""Loss functions used in the paper
"Score identity Distillation: Exponentially Fast Distillation of
Pretrained Diffusion Models for One-Step Generation"."""

import torch


# ----------------------------------------------------------------------------
class SID_EDMLoss:
    """Generator-side loss for Score identity Distillation (SiD).

    Noise levels are drawn inside ``generator_loss`` from the EDM/Karras
    schedule (sigma_min=0.002, sigma_max=80, rho=7). The constructor
    arguments are stored on the instance but are not read by
    ``generator_loss``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, beta_d=19.9, beta_min=0.1):
        # Stored for interface compatibility; generator_loss does not use them.
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.beta_d = beta_d
        self.beta_min = beta_min

    def generator_loss(self, true_score, fake_score, x, cond=None, alpha=1.2, tmax=800, mask=None, weights=None):
        """Compute the SiD generator loss for a batch of generated samples.

        Args:
            true_score: Teacher model; ``true_score.edm.denoise(x_t, cond, sigma)``
                is called to obtain its denoised estimate.
            fake_score: Student/fake score model; ``fake_score.denoise(x_t, cond,
                sigma)`` is called on the same noisy input.
            x: Generated samples. The per-sample noise level has shape
                ``(B, 1)``, so ``x`` is presumably ``(B, D)`` -- TODO confirm.
            cond: Conditioning forwarded unchanged to both ``denoise`` calls.
            alpha: SiD alpha; scales the ``(x_real - x_fake)`` term inside the
                second factor of the loss.
            tmax: Upper bound on the diffusion time in thousandths, i.e.
                ``t ~ U(0, tmax / 1000)``. ``t = 0`` maps to ``sigma_min`` and
                ``t = 1`` (``tmax=1000``) maps to ``sigma_max``.
            mask: Optional observation mask. After ``squeeze(1)``, columns 1 and 2
                are used and must broadcast against the per-element loss
                (presumably one column per outcome, y0 and y1 -- confirm).
                Float, integer and bool masks are accepted.
            weights: Optional per-sample weights of shape ``(B,)`` or ``(B, 1)``
                (importance weights, e.g. from a propensity network), multiplied
                into the per-element loss before masking/reduction.

        Returns:
            A scalar tensor. Without a mask: the sum of the per-element loss.
            With a mask: for each of the two masked columns, the mean loss over
            observed rows (a column with no observed rows contributes 0),
            multiplied by the batch size and summed over columns.
        """
        sigma_min = 0.002
        sigma_max = 80
        rho = 7.0
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)

        # One diffusion time per sample, mapped to sigma by interpolating in
        # sigma^(1/rho) space: t = 0 -> sigma_min, t = 1 -> sigma_max.
        rnd_t = torch.rand([x.shape[0], 1], device=x.device) * tmax / 1000
        sigma = (max_inv_rho + (1 - rnd_t) * (min_inv_rho - max_inv_rho)) ** rho

        # Both networks denoise the same noisy input.
        x_noisy = x + torch.randn_like(x) * sigma
        x_real = true_score.edm.denoise(x_noisy, cond, sigma)
        x_fake = fake_score.denoise(x_noisy, cond, sigma)

        diff = x_real - x_fake
        # Per-sample normaliser, computed without gradient: mean absolute
        # teacher/student gap over dim 1, floored to avoid division by ~0.
        # NOTE(review): the SiD paper normalises by ||x - f_phi(x_t)||_1 (gap
        # between the sample and the teacher estimate); this code uses
        # ||x_real - x_fake||_1 instead -- confirm this is intended.
        with torch.no_grad():
            weight_factor = diff.abs().to(torch.float32).mean(dim=[1], keepdim=True).clip(min=0.00001)

        loss = diff * ((x_real - x) - alpha * diff) / weight_factor

        if weights is not None:
            # Give 1-D weights a trailing axis so each weight scales its row.
            # Unsqueezing an already (B, 1) tensor would yield (B, 1, 1) and
            # silently broadcast the loss to (B, B, D).
            if weights.dim() == 1:
                weights = weights.unsqueeze(1)
            loss = loss * weights

        if mask is None:
            return loss.sum()

        observed = mask.squeeze(1)[:, 1:3]
        num_observed = observed.sum(0)
        # Use 1 as the denominator for columns with no observed rows (their
        # numerator is 0). ones_like keeps num_observed's dtype, so integer or
        # bool masks do not mix dtypes inside torch.where.
        denom = torch.where(num_observed > 0, num_observed, torch.ones_like(num_observed))
        per_column = (loss * observed).sum(0) / denom
        # Rescale the per-column means by the batch size so the magnitude is
        # comparable to the unmasked sum.
        return (per_column * observed.shape[0]).sum()
