# Copyright (c) 2024, Mingyuan Zhou. All rights reserved.
#
# This work is licensed under APACHE LICENSE, VERSION 2.0
# You should have received a copy of the license along with this
# work. If not, see https://www.apache.org/licenses/LICENSE-2.0.txt

import torch
from torch_utils import persistence
import torch.nn as nn


# NOTE: this string comes after the imports, so Python does not bind it as the
# module's __doc__; it only serves as a descriptive header for readers.
"""Loss functions used in the paper
"Adversarial Score Identity Distillation: Rapidly Surpassing the Teacher in One Step"."""

#----------------------------------------------------------------------------
@persistence.persistent_class
class FSIM_EDMLoss:
    """SiD generator / fake-score losses with an adversarial branch, for EDM-style denoisers.

    The fake-score network doubles as a discriminator: called with
    ``return_flag='encoder_decoder'`` it returns ``(denoised, logits)`` and with
    ``return_flag='encoder'`` it returns only ``logits``.

    All losses are returned unreduced. Any batch row containing a NaN in one of the
    tensors involved is dropped before the loss is formed, so the returned batch can
    be smaller than the input batch.

    Args:
        P_mean, P_std: mean / std of ``log(sigma)`` for the log-normal noise levels
            used by the fake-score (denoising) losses.
        sigma_data: data standard deviation used in the EDM loss weight.
        beta_d, beta_min: stored on the instance; not read by any method of this class.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, beta_d=19.9, beta_min=0.1):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        # Not read by any method of this class; kept on the instance as given.
        self.beta_d = beta_d
        self.beta_min = beta_min

    # ------------------------------------------------------------------
    # Private helpers shared by the public losses.
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_generator_sigma(images, tmax, sigma_min=0.002, sigma_max=80, rho=7.0):
        """Draw one noise level per sample for the generator losses.

        A uniform ``t`` in ``[0, tmax / 1000)`` is mapped through the rho-interpolated
        EDM (Karras) schedule: ``t = 0`` gives ``sigma_min``, and a larger ``tmax``
        admits noise levels closer to ``sigma_max``. Returns a ``[B, 1, 1, 1]`` tensor.
        """
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        rnd_t = torch.rand([images.shape[0], 1, 1, 1], device=images.device) * tmax / 1000
        return (max_inv_rho + (1 - rnd_t) * (min_inv_rho - max_inv_rho)) ** rho

    def _sample_training_sigma(self, images):
        """Draw log-normal noise levels and their EDM loss weights, both ``[B, 1, 1, 1]``."""
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        return sigma, weight

    @staticmethod
    def _augment(x, augment_pipe, like):
        """Return ``augment_pipe(x)``, or ``(x, zeros[B, 9])`` when no pipe is given.

        The zero augment labels take their batch size and device from ``like``.
        """
        if augment_pipe is not None:
            return augment_pipe(x)
        return x, torch.zeros(like.shape[0], 9, device=like.device)

    @staticmethod
    def _nan_rows(*tensors):
        """Return a bool ``[B]`` mask that is True where any given tensor has a NaN in that row.

        ``None`` entries are skipped. Each tensor must have at least 2 dims, because it
        is flattened from dim 1.
        """
        mask = None
        for t in tensors:
            if t is None:
                continue
            row_has_nan = torch.isnan(t).flatten(start_dim=1).any(dim=1)
            mask = row_has_nan if mask is None else mask | row_has_nan
        return mask

    @staticmethod
    def _drop_nan_rows(nan_mask, *tensors):
        """Remove the rows flagged in ``nan_mask`` from every tensor (``None`` passes through).

        When no row is flagged, the tensors are returned untouched (no indexing op).
        """
        if not nan_mask.any():
            return tensors
        keep = ~nan_mask
        return tuple(None if t is None else t[keep] for t in tensors)

    @staticmethod
    def _sid_loss(y, y_real, y_fake, alpha):
        """Compute the per-element SiD generator loss and its per-sample normalizer.

        The normalizer is ``mean |y - y_real|`` over dims 1-3, computed without
        gradient and clipped at 1e-5. It is returned as well so that the adversarial
        term can be scaled the same way.
        """
        with torch.no_grad():
            weight_factor = abs(y - y_real).to(torch.float32).mean(dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
        loss = (y_real - y_fake) * ((y_real - y) - alpha * (y_real - y_fake)) / weight_factor
        return loss, weight_factor

    @staticmethod
    def _generator_gan_loss(y_D, weight_factor):
        """Return the mean BCE of logits clamped to [-10, 10] against all-ones targets, divided by ``weight_factor``."""
        bce = nn.BCEWithLogitsLoss()(y_D.clamp(-10, 10), torch.ones_like(y_D))
        return bce.to(torch.float32) / weight_factor

    @staticmethod
    def _discriminator_loss(logit_real, logit_fake, weight):
        """Return ``weight * (BCE(real -> 1) + BCE(fake -> 0)) / 2``, each BCE averaged over its batch."""
        bce = nn.BCEWithLogitsLoss()
        loss_real = bce(logit_real, torch.ones_like(logit_real))
        loss_fake = bce(logit_fake, torch.zeros_like(logit_fake))
        return weight * (loss_real + loss_fake) / 2

    # ------------------------------------------------------------------
    # Public losses.
    # ------------------------------------------------------------------

    def generator_share_encoder_loss(self, true_score, fake_score, images, labels=None, augment_pipe=None, alpha=1.2, tmax=800, return_y_D=True):
        """Generator loss: the SiD term, plus optionally the adversarial term from fake_score's logits.

        Args:
            true_score: true-score denoiser, called as
                ``true_score(x, sigma, labels, augment_labels=...)``.
            fake_score: fake-score denoiser; with ``return_flag='encoder_decoder'`` it
                must return ``(denoised, logits)``.
            images: generator outputs. The SiD normalizer averages over dims 1-3, so a
                rank-4 batch is expected (presumably ``[B, C, H, W]``).
            labels: conditioning labels forwarded to both networks.
            augment_pipe: optional callable returning ``(augmented, augment_labels)``.
            alpha: SiD ``alpha`` coefficient.
            tmax: upper bound (out of 1000) on the sampled diffusion time.
            return_y_D: if True, also compute the adversarial generator term.

        Returns:
            ``loss`` when ``return_y_D`` is False, otherwise ``(loss, loss_gan)``.
            ``loss`` is per-element. ``loss_gan`` is the scalar mean BCE divided by the
            per-sample normalizer, so it has shape ``[B', 1, 1, 1]``, where ``B'`` excludes
            NaN rows.
        """
        sigma = self._sample_generator_sigma(images, tmax)
        y, augment_labels = self._augment(images, augment_pipe, images)
        x_t = y + torch.randn_like(y) * sigma

        y_real = true_score(x_t, sigma, labels, augment_labels=augment_labels)
        y_D = None
        if return_y_D:
            y_fake, y_D = fake_score(x_t, sigma, labels, augment_labels=augment_labels, return_flag='encoder_decoder')
        else:
            y_fake = fake_score(x_t, sigma, labels, augment_labels=augment_labels)

        nan_mask = self._nan_rows(y, y_real, y_fake, y_D)
        y, y_real, y_fake, y_D = self._drop_nan_rows(nan_mask, y, y_real, y_fake, y_D)

        loss, weight_factor = self._sid_loss(y, y_real, y_fake, alpha)
        if not return_y_D:
            return loss
        return loss, self._generator_gan_loss(y_D, weight_factor)

    def fakescore_discriminator_share_encoder_loss(self, fake_score, images, labels=None, augment_pipe=None, real_images=None, true_score=None, alpha=None):
        """Compute the fake-score denoising loss on generated images plus the discriminator loss.

        The generated ``images`` and the ``real_images`` are noised with the same
        ``sigma`` and the same noise draw. ``true_score`` and ``alpha`` are accepted but
        not used.

        Returns:
            ``(loss_fake_score, loss_D)``. ``loss_fake_score`` is the per-element
            EDM-weighted squared error on the generated images. ``loss_D`` is
            ``weight * (BCE_real + BCE_fake) / 2``.

        Raises:
            ValueError: if ``real_images`` is None.
        """
        if real_images is None:
            raise ValueError('real_images is required for the discriminator loss')

        sigma, weight = self._sample_training_sigma(images)
        y, augment_labels = self._augment(images, augment_pipe, images)
        n = torch.randn_like(y) * sigma
        y_fake, logit_fake = fake_score(y + n, sigma, labels, augment_labels=augment_labels, return_flag='encoder_decoder')

        # Real images are augmented independently but reuse the same noise draw n.
        y_real_images, real_augment_labels = self._augment(real_images, augment_pipe, images)
        logit_real = fake_score(y_real_images + n, sigma, labels, augment_labels=real_augment_labels, return_flag='encoder')

        nan_mask = self._nan_rows(y, y_fake, logit_fake, logit_real)
        y, y_fake, weight, logit_fake, logit_real = self._drop_nan_rows(
            nan_mask, y, y_fake, weight, logit_fake, logit_real)

        loss_fake_score = weight * ((y_fake - y) ** 2)
        loss_D = self._discriminator_loss(logit_real, logit_fake, weight)
        return loss_fake_score, loss_D

    def generator_loss(self, true_score, fake_score, images, labels=None, augment_pipe=None, alpha=1.2, tmax=800):
        """Return the per-element SiD generator loss without the adversarial term.

        Noise sampling, augmentation and NaN-row filtering work as in
        ``generator_share_encoder_loss``.
        """
        sigma = self._sample_generator_sigma(images, tmax)
        y, augment_labels = self._augment(images, augment_pipe, images)
        x_t = y + torch.randn_like(y) * sigma

        y_real = true_score(x_t, sigma, labels, augment_labels=augment_labels)
        y_fake = fake_score(x_t, sigma, labels, augment_labels=augment_labels)

        nan_mask = self._nan_rows(y, y_real, y_fake)
        y, y_real, y_fake = self._drop_nan_rows(nan_mask, y, y_real, y_fake)

        loss, _ = self._sid_loss(y, y_real, y_fake, alpha)
        return loss

    def __call__(self, fake_score, images, labels=None, augment_pipe=None):
        """Return the per-element EDM-weighted denoising loss of ``fake_score`` on ``images``.

        Without an ``augment_pipe``, ``augment_labels=None`` is passed to the network.
        """
        sigma, weight = self._sample_training_sigma(images)
        y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        y_fake = fake_score(y + torch.randn_like(y) * sigma, sigma, labels, augment_labels=augment_labels)

        nan_mask = self._nan_rows(y, y_fake)
        y, y_fake, weight = self._drop_nan_rows(nan_mask, y, y_fake, weight)
        return weight * ((y_fake - y) ** 2)
    
    

    
@persistence.persistent_class
class EDM2Loss:
    """EDM2-style denoising loss with a learned log-variance (uncertainty) weighting.

    Noise levels are log-normal, ``log(sigma) ~ N(P_mean, P_std**2)``. The network is
    called with ``return_logvar=True`` and must return ``(denoised, logvar)``. The
    per-element loss is ``weight / exp(logvar) * (denoised - images)**2 + logvar``.
    """

    def __init__(self, P_mean=-0.4, P_std=1.0, sigma_data=0.5):
        self.P_mean, self.P_std, self.sigma_data = P_mean, P_std, sigma_data

    def __call__(self, net, images, labels=None):
        """Return the unreduced loss tensor for one batch of clean ``images``."""
        per_sample_shape = [images.shape[0], 1, 1, 1]
        log_sigma = self.P_mean + self.P_std * torch.randn(per_sample_shape, device=images.device)
        sigma = log_sigma.exp()

        sd_sq = self.sigma_data ** 2
        weight = (sigma ** 2 + sd_sq) / (sigma * self.sigma_data) ** 2

        corrupted = images + sigma * torch.randn_like(images)
        denoised, logvar = net(corrupted, sigma, labels, return_logvar=True)

        squared_error = (denoised - images) ** 2
        return (weight / logvar.exp()) * squared_error + logvar
    
    
    
@persistence.persistent_class
class FSIM_EDM_2_Loss:
    """SiD generator / fake-score losses with an adversarial branch, for EDM2-style denoisers.

    When ``return_logvar=True``, the fake-score network is called with
    ``return_logvar=True``. The log-variance it returns rescales the losses:
    division by ``exp(logvar)``, plus ``+ logvar`` in the denoising losses, which is
    the same form as ``EDM2Loss.__call__``.

    The fake-score network doubles as a discriminator via ``return_flag``:
    ``'encoder_decoder'`` yields the denoised output and logits, and ``'encoder'``
    yields logits only.

    Generated images are never augmented. ``augment_pipe`` is ignored by the
    generator losses, and in the discriminator loss it is applied only to
    ``real_images``.

    All losses are unreduced. Batch rows containing a NaN in any involved tensor are
    dropped first.
    """

    def __init__(self, P_mean=-0.4, P_std=1.0, sigma_data=0.5):
        # Alternative values noted by the author: P_mean=-1.2, P_std=1.2.
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    # ------------------------------------------------------------------
    # Private helpers shared by the public losses.
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_generator_sigma(images, tmax, sigma_min=0.002, sigma_max=80, rho=7.0):
        """Draw one noise level per sample for the generator losses.

        A uniform ``t`` in ``[0, tmax / 1000)`` is mapped through the rho-interpolated
        EDM (Karras) schedule: ``t = 0`` gives ``sigma_min``, and a larger ``tmax``
        admits noise levels closer to ``sigma_max``. Returns a ``[B, 1, 1, 1]`` tensor.
        """
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        rnd_t = torch.rand([images.shape[0], 1, 1, 1], device=images.device) * tmax / 1000
        return (max_inv_rho + (1 - rnd_t) * (min_inv_rho - max_inv_rho)) ** rho

    def _sample_training_sigma(self, images):
        """Draw log-normal noise levels and their EDM loss weights, both ``[B, 1, 1, 1]``."""
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        return sigma, weight

    @staticmethod
    def _nan_rows(*tensors):
        """Return a bool ``[B]`` mask that is True where any given tensor has a NaN in that row.

        ``None`` entries are skipped. Each tensor must have at least 2 dims, because it
        is flattened from dim 1.
        """
        mask = None
        for t in tensors:
            if t is None:
                continue
            row_has_nan = torch.isnan(t).flatten(start_dim=1).any(dim=1)
            mask = row_has_nan if mask is None else mask | row_has_nan
        return mask

    @staticmethod
    def _drop_nan_rows(nan_mask, *tensors):
        """Remove the rows flagged in ``nan_mask`` from every tensor (``None`` passes through).

        When no row is flagged, the tensors are returned untouched (no indexing op).
        """
        if not nan_mask.any():
            return tensors
        keep = ~nan_mask
        return tuple(None if t is None else t[keep] for t in tensors)

    @staticmethod
    def _sid_loss(y, y_real, y_fake, alpha):
        """Compute the per-element SiD generator loss and its per-sample normalizer.

        The normalizer is ``mean |y - y_real|`` over dims 1-3, computed without
        gradient and clipped at 1e-5.
        """
        with torch.no_grad():
            weight_factor = abs(y - y_real).to(torch.float32).mean(dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
        loss = (y_real - y_fake) * ((y_real - y) - alpha * (y_real - y_fake)) / weight_factor
        return loss, weight_factor

    @staticmethod
    def _generator_gan_loss(y_D, weight_factor):
        """Return the mean BCE of logits clamped to [-10, 10] against all-ones targets, divided by ``weight_factor``."""
        bce = nn.BCEWithLogitsLoss()(y_D.clamp(-10, 10), torch.ones_like(y_D))
        return bce.to(torch.float32) / weight_factor

    @staticmethod
    def _discriminator_loss(logit_real, logit_fake, weight):
        """Return ``weight * (BCE(real -> 1) + BCE(fake -> 0)) / 2``, each BCE averaged over its batch."""
        bce = nn.BCEWithLogitsLoss()
        loss_real = bce(logit_real, torch.ones_like(logit_real))
        loss_fake = bce(logit_fake, torch.zeros_like(logit_fake))
        return weight * (loss_real + loss_fake) / 2

    # ------------------------------------------------------------------
    # Public losses.
    # ------------------------------------------------------------------

    def generator_share_encoder_loss(self, true_score, fake_score, images, labels=None, augment_pipe=None, alpha=1.2, tmax=800, return_y_D=True, return_logvar=True):
        """Generator loss: the SiD term, plus optionally the adversarial term from fake_score's logits.

        Args:
            true_score: true-score denoiser, called as ``true_score(x, sigma, labels)``.
            fake_score: fake-score denoiser. It returns ``(denoised, logits, logvar)``,
                ``(denoised, logvar)``, ``(denoised, logits)`` or ``denoised``, depending on
                ``return_y_D`` / ``return_logvar``.
            images: generator outputs. The SiD normalizer averages over dims 1-3, so a
                rank-4 batch is expected (presumably ``[B, C, H, W]``).
            labels: conditioning labels forwarded to both networks.
            augment_pipe: accepted but ignored.
            alpha: SiD ``alpha`` coefficient.
            tmax: upper bound (out of 1000) on the sampled diffusion time.
            return_y_D: if True, also compute the adversarial generator term.
            return_logvar: if True, divide the losses by ``exp(logvar)`` from fake_score.

        Returns:
            ``loss`` when ``return_y_D`` is False, otherwise
            ``(loss, loss_gan, y_real, y_fake, y)``. All returned tensors are
            NaN-row-filtered.
        """
        sigma = self._sample_generator_sigma(images, tmax)
        y = images
        x_t = y + torch.randn_like(y) * sigma

        y_real = true_score(x_t, sigma, labels)

        logvar = y_D = None
        if return_logvar and return_y_D:
            y_fake, y_D, logvar = fake_score(x_t, sigma, labels, return_flag='encoder_decoder', return_logvar=True)
        elif return_logvar:
            y_fake, logvar = fake_score(x_t, sigma, labels, return_logvar=True)
        elif return_y_D:
            y_fake, y_D = fake_score(x_t, sigma, labels, return_flag='encoder_decoder')
        else:
            y_fake = fake_score(x_t, sigma, labels)

        nan_mask = self._nan_rows(y, y_real, y_fake, logvar, y_D)
        y, y_real, y_fake, logvar, y_D = self._drop_nan_rows(nan_mask, y, y_real, y_fake, logvar, y_D)

        loss, weight_factor = self._sid_loss(y, y_real, y_fake, alpha)
        if return_logvar:
            loss = loss / logvar.exp()
        if not return_y_D:
            return loss

        loss_gan = self._generator_gan_loss(y_D, weight_factor)
        if return_logvar:
            loss_gan = loss_gan / logvar.exp()
        # TODO(review): the original author left open whether y_real, y_fake and y
        # should also be rescaled by logvar before being returned.
        return loss, loss_gan, y_real, y_fake, y

    def fakescore_discriminator_share_encoder_loss(self, fake_score, images, labels=None, augment_pipe=None, real_images=None, true_score=None, alpha=None, return_logvar=True):
        """Compute the fake-score denoising loss on generated images plus the discriminator loss.

        The generated ``images`` and the ``real_images`` are noised with the same
        ``sigma`` and the same noise draw. ``true_score`` and ``alpha`` are accepted but
        not used.

        Returns:
            ``(loss_fake_score, loss_D)``. With ``return_logvar``, the denoising loss is
            ``weight / exp(logvar) * err**2 + logvar`` and ``loss_D`` is divided by
            ``exp(logvar)``. The logvar comes from the fake-image forward pass.

        Raises:
            ValueError: if ``real_images`` is None.
        """
        if real_images is None:
            raise ValueError('real_images is required for the discriminator loss')

        sigma, weight = self._sample_training_sigma(images)
        y = images
        n = torch.randn_like(y) * sigma

        logvar = None
        if return_logvar:
            y_fake, logit_fake, logvar = fake_score(y + n, sigma, labels, return_logvar=True, return_flag='encoder_decoder')
        else:
            y_fake, logit_fake = fake_score(y + n, sigma, labels, return_flag='encoder_decoder')

        # Only the real images go through augment_pipe, and its augment labels are not
        # forwarded to the network.
        if augment_pipe is not None:
            real_images, _ = augment_pipe(real_images)
        logit_real = fake_score(real_images + n, sigma, labels, return_flag='encoder')

        nan_mask = self._nan_rows(y, y_fake, logit_fake, logit_real, logvar)
        y, y_fake, weight, logit_fake, logit_real, logvar = self._drop_nan_rows(
            nan_mask, y, y_fake, weight, logit_fake, logit_real, logvar)

        loss_D = self._discriminator_loss(logit_real, logit_fake, weight)
        if not return_logvar:
            return weight * ((y_fake - y) ** 2), loss_D

        logvar_exp = logvar.exp()
        loss_fake_score = (weight / logvar_exp) * ((y_fake - y) ** 2) + logvar
        return loss_fake_score, loss_D / logvar_exp

    def generator_loss(self, true_score, fake_score, images, labels, augment_pipe=None, alpha=1.2, tmax=800, return_logvar=True):
        """Return the per-element SiD generator loss without the adversarial term.

        With ``return_logvar``, both networks are queried with ``return_logvar=True``.
        Only fake_score's logvar is used to scale the loss; true_score's is discarded.
        ``augment_pipe`` is accepted but ignored.
        """
        sigma = self._sample_generator_sigma(images, tmax)
        y = images
        x_t = y + torch.randn_like(y) * sigma

        logvar = None
        if return_logvar:
            # TODO(review): the original author notes an earlier version scaled by
            # true_score's logvar instead; this version uses fake_score's.
            y_real, _ = true_score(x_t, sigma, labels, return_logvar=True)
            y_fake, logvar = fake_score(x_t, sigma, labels, return_logvar=True)
        else:
            y_real = true_score(x_t, sigma, labels)
            y_fake = fake_score(x_t, sigma, labels)

        nan_mask = self._nan_rows(y, y_real, y_fake, logvar)
        y, y_real, y_fake, logvar = self._drop_nan_rows(nan_mask, y, y_real, y_fake, logvar)

        loss, _ = self._sid_loss(y, y_real, y_fake, alpha)
        if return_logvar:
            loss = loss / logvar.exp()
        return loss

    def __call__(self, fake_score, images, labels, return_logvar=True):
        """Return the per-element denoising loss of ``fake_score`` on ``images``.

        With ``return_logvar``, the loss is ``weight / exp(logvar) * err**2 + logvar``;
        otherwise it is ``weight * err**2``.
        """
        sigma, weight = self._sample_training_sigma(images)
        y = images
        x_t = y + torch.randn_like(y) * sigma

        logvar = None
        if return_logvar:
            y_fake, logvar = fake_score(x_t, sigma, labels, return_logvar=True)
        else:
            y_fake = fake_score(x_t, sigma, labels)

        nan_mask = self._nan_rows(y, y_fake, logvar)
        y, y_fake, weight, logvar = self._drop_nan_rows(nan_mask, y, y_fake, weight, logvar)

        if return_logvar:
            return (weight / logvar.exp()) * ((y_fake - y) ** 2) + logvar
        return weight * ((y_fake - y) ** 2)