import torch

# Machine epsilon of float32 (gap between 1.0 and the next representable value).
EPS=torch.finfo(torch.float32).eps
# Smallest positive normal float32; used below as a lower clamp before taking
# the log of Gamma samples so that log(0) cannot occur.
MIN=torch.finfo(torch.float32).tiny

from torch_utils import persistence

from torch.nn.functional import logsigmoid

#----------------------------------------------------------------------------

@persistence.persistent_class
class VPLoss:
    """Weighted denoising loss with noise levels taken from ``sigma(t)``.

    Each call draws one ``t`` per image between ``epsilon_t`` and 1, turns it
    into a noise level via :meth:`sigma`, adds Gaussian noise of that scale to
    the (optionally augmented) images and returns the squared denoising error
    multiplied by ``1 / sigma**2``.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        # Coefficients of the quadratic exponent used by sigma(t).
        self.beta_d = beta_d
        self.beta_min = beta_min
        # Lower end of the interval from which t is sampled.
        self.epsilon_t = epsilon_t

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the unreduced weighted squared error of ``net`` on noised images.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``.
        ``augment_pipe``, when given, must return ``(images, augment_labels)``.
        """
        batch_size = images.shape[0]
        uniform = torch.rand([batch_size, 1, 1, 1], device=images.device)
        # uniform in [0, 1) maps to t in (epsilon_t, 1].
        sigma = self.sigma(1 + uniform * (self.epsilon_t - 1))
        weight = 1 / sigma ** 2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - y) ** 2
        return weight * squared_error

    def sigma(self, t):
        """Return ``sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)`` as a tensor."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return torch.sqrt(torch.exp(exponent) - 1)

#----------------------------------------------------------------------------


@persistence.persistent_class
class VELoss:
    """Weighted denoising loss with log-uniformly distributed noise levels.

    The noise level of each image is ``sigma_min * (sigma_max / sigma_min)**u``
    with ``u`` uniform on [0, 1); the squared denoising error is multiplied by
    ``1 / sigma**2``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        # Bounds of the sampled noise level.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the unreduced weighted squared error of ``net`` on noised images.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``.
        ``augment_pipe``, when given, must return ``(images, augment_labels)``.
        """
        batch_size = images.shape[0]
        uniform = torch.rand([batch_size, 1, 1, 1], device=images.device)
        ratio = self.sigma_max / self.sigma_min
        sigma = self.sigma_min * (ratio ** uniform)
        weight = 1 / sigma ** 2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - y) ** 2
        return weight * squared_error

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    """Weighted denoising loss with log-normally distributed noise levels.

    ``log(sigma)`` is drawn from a normal distribution with mean ``P_mean``
    and standard deviation ``P_std``.  The squared denoising error is weighted
    by ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        # Mean and standard deviation of log(sigma).
        self.P_mean = P_mean
        self.P_std = P_std
        # Constant entering the loss weight.
        self.sigma_data = sigma_data

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Return the unreduced weighted squared error of ``net`` on noised images.

        ``net`` is called as ``net(noisy, sigma, labels, augment_labels=...)``.
        ``augment_pipe``, when given, must return ``(images, augment_labels)``.
        """
        batch_size = images.shape[0]
        normal = torch.randn([batch_size, 1, 1, 1], device=images.device)
        log_sigma = normal * self.P_std + self.P_mean
        sigma = log_sigma.exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        squared_error = (denoised - y) ** 2
        return weight * squared_error
    
# Loss variant that processes the first three (color) channels of the image separately.
@persistence.persistent_class
class GDDLoss:
    """Diffusion training loss evaluated on three single-channel slices.

    A noisy observation is drawn in logit space as ``log u - log v`` with
    ``u ~ Gamma(eta * alpha * x0)`` and ``v ~ Gamma(eta - eta * alpha * x0)``
    for channel slices ``0:1``, ``1:2`` and ``2:3`` of ``x0``.  The result is
    standardised and passed to ``net``; the network output is mapped back to
    ``[Shift, Shift + Scale]`` and compared with ``x0`` using the divergence
    selected by ``lossType``.

    Supported ``lossType`` values: ``'KLUB'``, ``'HPD'``, ``'HPD1'``, ``'PHD'``.
    ``'KLUB'`` and ``'HPD'`` are averaged over the three channel slices;
    ``'HPD1'`` and ``'PHD'`` (which sum over dim 1 internally) are evaluated
    once on the whole tensor.
    """

    # Loss types that compute_loss can evaluate.
    _LOSS_TYPES = ("KLUB", "HPD", "HPD1", "PHD")
    # Loss types that __call__ evaluates on the full tensor rather than per channel.
    _FULL_TENSOR_LOSS_TYPES = ("HPD1", "PHD")

    def __init__(self, eta=None, sigmoid_start=None, sigmoid_end=None, sigmoid_power=None, Scale=None, Shift=None, T=500, epsilon_t=1e-5,lossType='KLUB'):
        # Concentration used for channel slices 0, 1 and 2 respectively; all
        # three start from the same value.
        self.eta = eta
        self.eta1 = eta
        self.eta2 = eta

        # logit(alpha) = sigmoid_start + (sigmoid_end - sigmoid_start) * t**sigmoid_power
        self.sigmoid_start = sigmoid_start
        self.sigmoid_end = sigmoid_end
        self.sigmoid_power = sigmoid_power
        # Images are affinely mapped to [Shift, Shift + Scale] before use.
        self.Scale = Scale
        self.Shift = Shift
        self.T = T  # stored only; not read by any method of this class
        self.lossType = lossType
        # Lower end of the interval from which the time position is sampled.
        self.epsilon_t = epsilon_t
        self.min = torch.finfo(torch.float32).tiny
        self.eps = torch.finfo(torch.float32).eps

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Sample a noisy input, run ``net`` and return the unreduced loss.

        Args:
            net: callable invoked as
                ``net(x, logit_alpha, labels, augment_labels=...)``; its output
                is interpreted as the logit of the normalised clean image.
            images: 4-D image batch; values are mapped from ``[-1, 1]`` to
                ``[0, 1]`` (with clamping) and then to ``[Shift, Shift + Scale]``.
            labels: forwarded to ``net`` unchanged.
            augment_pipe: optional callable returning ``(images, augment_labels)``.

        Returns:
            The loss tensor; no reduction over the batch is applied here.

        Raises:
            NotImplementedError: if ``lossType`` is not one of the supported
                values.  This is checked before any sampling or forward pass.
        """
        if self.lossType not in self._LOSS_TYPES:
            raise NotImplementedError("Loss type not implemented")

        batch_size, device = images.shape[0], images.device

        # Per-sample time position in (epsilon_t, 1] and a slightly earlier
        # position (95% of it) that defines the conditional step.
        rnd_uniform = torch.rand([batch_size, 1, 1, 1], device=device)
        rnd_position = 1 + rnd_uniform * (self.epsilon_t - 1)
        logit_alpha = self._logit_alpha(rnd_position)
        logit_alpha_previous = self._logit_alpha(rnd_position * 0.95)

        alpha = logit_alpha.sigmoid()
        alpha_previous = logit_alpha_previous.sigmoid()
        # The difference of two nearby sigmoids is taken in float64 to limit
        # cancellation error, then cast back.
        delta = (logit_alpha_previous.to(torch.float64).sigmoid()
                 - logit_alpha.to(torch.float64).sigmoid()).to(torch.float32)

        ones = torch.ones([batch_size, 1, 1, 1], device=device)
        channel_etas = (self.eta, self.eta1, self.eta2)

        # Prepare x0 in [Shift, Shift + Scale].
        x0, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        x0 = ((x0 + 1.0) / 2.0).clamp(0, 1) * self.Scale + self.Shift

        # Draw the noisy logit channel by channel.  The order of the Gamma
        # draws (u then v, channel 0 to 2) is kept fixed so that seeded runs
        # reproduce the same samples.
        channel_logits = []
        for index, eta_c in enumerate(channel_etas):
            x0_c = x0[:, index:index + 1, :, :]
            log_u = self.log_gamma(eta_c * alpha * x0_c)
            log_v = self.log_gamma(eta_c - eta_c * alpha * x0_c)
            channel_logits.append(log_u - log_v)
        logit_x_t = torch.cat(channel_logits, dim=1)

        E_logit_x_t, std_logit_x_t = self._logit_moments(alpha)

        logit_x0_hat = net((logit_x_t - E_logit_x_t) / std_logit_x_t, logit_alpha, labels, augment_labels=augment_labels)
        x0_hat = torch.sigmoid(logit_x0_hat) * self.Scale + self.Shift

        if self.lossType in self._FULL_TENSOR_LOSS_TYPES:
            # These losses sum over dim 1 themselves, so they are evaluated
            # once on all channels instead of per slice.
            return self.compute_loss(x0, x0_hat, alpha, alpha_previous, ones * self.eta, delta)

        loss_r, loss_g, loss_b = (
            self.compute_loss(
                x0[:, index:index + 1, :, :],
                x0_hat[:, index:index + 1, :, :],
                alpha, alpha_previous, ones * eta_c, delta,
            )
            for index, eta_c in enumerate(channel_etas)
        )
        return (loss_r + loss_g + loss_b) / 3.0

    def _logit_alpha(self, position):
        """Return ``sigmoid_start + (sigmoid_end - sigmoid_start) * position**sigmoid_power``."""
        return self.sigmoid_start + (self.sigmoid_end - self.sigmoid_start) * (position ** self.sigmoid_power)

    def _logit_moments(self, alpha):
        """Return the mean and standard deviation used to standardise the noisy logit.

        E1/E2 and V1/V2 are the closed-form averages of the digamma and
        trigamma terms over ``x`` uniform on ``[Shift, Shift + Scale]``.
        V3/V4 add the variance of the digamma terms over that range, estimated
        on a 101-point grid with the trapezoid rule.  The covariance between
        the two digamma terms is not included.  ``eta1`` is used throughout.
        """
        eta = self.eta1
        xmin = self.Shift
        xmax = self.Shift + self.Scale
        inv_width = 1.0 / (eta * alpha * self.Scale)

        E1 = inv_width * ((eta * alpha * xmax).lgamma() - (eta * alpha * xmin).lgamma())
        E2 = inv_width * ((eta - eta * alpha * xmin).lgamma() - (eta - eta * alpha * xmax).lgamma())

        V1 = inv_width * ((eta * alpha * xmax).digamma() - (eta * alpha * xmin).digamma())
        V2 = inv_width * ((eta - eta * alpha * xmin).digamma() - (eta - eta * alpha * xmax).digamma())

        grids = (torch.arange(0, 101, device=alpha.device) / 100) * self.Scale + self.Shift
        alpha_x = alpha[:, :, 0, 0] * grids.unsqueeze(0)
        V3 = self._trapezoid_variance((eta * alpha_x).digamma(), E1)
        V4 = self._trapezoid_variance((eta - eta * alpha_x).digamma(), E2)

        return E1 - E2, (V1 + V2 + V3 + V4).sqrt()

    @staticmethod
    def _trapezoid_variance(values, mean):
        """Return ``clamp(mean(values**2) - mean**2, 0)`` with a trapezoid-rule mean.

        ``values`` has one row per sample and one column per grid point; the
        result is reshaped to ``(batch, 1, 1, 1)`` before ``mean**2`` is
        subtracted.
        """
        squared = values ** 2
        # Trapezoid rule: average the two end points into the first slot and
        # drop the last one, then take a plain mean over the rest.
        squared[:, 0] = (squared[:, 0] + squared[:, -1]) / 2
        mean_square = squared[:, :-1].mean(dim=1)
        return (mean_square.unsqueeze(1).unsqueeze(2).unsqueeze(3) - mean ** 2).clamp(0)

    def compute_loss(self, x0, x0_hat, alpha, alpha_previous, eta, delta):
        """Evaluate the divergence selected by ``lossType`` between ``x0`` and ``x0_hat``.

        Only the requested divergence is computed.

        Raises:
            NotImplementedError: if ``lossType`` is not supported.
        """
        loss_type = self.lossType
        if loss_type not in self._LOSS_TYPES:
            raise NotImplementedError("Loss type not implemented")

        # Parameters of the conditional step for the data (p) and the prediction (q).
        alpha_p = eta * delta * x0
        beta_p = eta - eta * alpha_previous * x0
        alpha_q = eta * delta * x0_hat
        beta_q = eta - eta * alpha_previous * x0_hat

        if loss_type == "HPD":
            return self.HPD(alpha_p, alpha_q, beta_p, beta_q)
        if loss_type == "HPD1":
            return self._decomposed(self.HPD1, alpha_p, beta_p, alpha_q, beta_q)
        if loss_type == "PHD":
            return self._decomposed(self.PHD, alpha_p, beta_p, alpha_q, beta_q)

        # 'KLUB': mix of the conditional term and the marginal term at level alpha.
        conditional = self._decomposed(self.KL_gamma, alpha_p, beta_p, alpha_q, beta_q)
        marginal = self._decomposed(
            self.KL_gamma,
            eta * alpha * x0,
            eta - eta * alpha * x0,
            eta * alpha * x0_hat,
            eta - eta * alpha * x0_hat,
        )
        return 0.95 * conditional + 0.05 * marginal

    @staticmethod
    def _decomposed(divergence, alpha_p, beta_p, alpha_q, beta_q):
        """Return ``D(alpha_q, alpha_p) + D(beta_q, beta_p) - D(sum_q, sum_p)``.

        Each of the three terms and the final result are clamped at zero.
        """
        return (divergence(alpha_q, alpha_p).clamp(0)
                + divergence(beta_q, beta_p).clamp(0)
                - divergence(alpha_q + beta_q, alpha_p + beta_p).clamp(0)).clamp(0)

    @staticmethod
    def _log_multivariate_beta(x):
        """Return ``sum(lgamma(x)) - lgamma(sum(x))`` over dim 1, keeping that dim."""
        return torch.sum(torch.lgamma(x), dim=1, keepdim=True) - torch.lgamma(torch.sum(x, dim=1, keepdim=True))

    def HPD(self,alpha_p,alpha_q,beta_p,beta_q):
        """Element-wise divergence built from log-beta-function terms.

        Uses exponents ``a = 2`` and ``b = a / (a - 1)``; every log-beta term
        is clamped at zero before being combined.
        """
        a = 2.0
        b = a/(a-1)
        number1=1.0
        F1=(torch.lgamma((a * alpha_p)) + torch.lgamma((a * beta_p))-torch.lgamma(a*(alpha_p+beta_p))).clamp(0)
        F2=(torch.lgamma((b * alpha_q)) + torch.lgamma((b * beta_q))-torch.lgamma(b*(alpha_q+beta_q))).clamp(0)
        F3=(torch.lgamma((alpha_p+alpha_q-number1))+torch.lgamma((beta_p+beta_q-number1) ) -torch.lgamma((alpha_p+alpha_q-number1+beta_p+beta_q-number1))).clamp(0)
        hpd=1/a *F1+ 1/b *F2-F3
        return -hpd

    def HPD1(self,alpha_p,alpha_q):
        """Symmetrised divergence over dim 1 using conjugate exponents ``(a, b)``.

        The term ``B(a*p)/a + B(b*q)/b - B(p + q)`` (``B`` being the log
        multivariate beta function over dim 1) is evaluated for ``(a, b)`` and
        for the swapped pair, and the two results are averaged.
        """
        log_beta = self._log_multivariate_beta

        def term(a, b):
            return 1/a * log_beta(a * alpha_p) + 1/b * log_beta(b * alpha_q) - log_beta(alpha_p + alpha_q)

        a = 2.0
        b = a/(a-1)
        hd1 = term(a, b)
        a, b = b, b/(b-1)
        hd2 = term(a, b)
        return 1/2 * (hd1+hd2)

    def HPD2(self,alpha_p,alpha_q):
        """Same as :meth:`HPD1` but with 1 added to every argument of ``B``.

        The offset is a scalar, so inputs of any width are accepted.  The
        previous ``ones((1, 32))`` tensor only broadcast against a last
        dimension of 32 (or 1).
        """
        log_beta = self._log_multivariate_beta
        offset = 1.0

        def term(a, b):
            return (1/a * log_beta(a * alpha_p + offset)
                    + 1/b * log_beta(b * alpha_q + offset)
                    - log_beta(alpha_p + alpha_q + offset))

        a = 2.0
        b = a/(a-1)
        hd1 = term(a, b)
        a, b = b, b/(b-1)
        hd2 = term(a, b)
        return 1/2 * (hd1+hd2)

    def PHD(self,alpha_p,alpha_q):
        """Symmetrised divergence over dim 1 with a mixed third term.

        Computes ``B(g*p)/a + B(g*q)/b - B(g/a * p + g/b * q)`` with ``g = 1``
        for ``(a, b)`` and for the swapped pair, then averages.  The mixed
        argument of the third term is built once from the initial ``(a, b)``.
        """
        log_beta = self._log_multivariate_beta
        a = 2.0
        b = a/(a-1)
        g = 1.0

        sum_term = (g / a) * alpha_p + (g / b) * alpha_q

        # None of the three B terms depends on the exponent swap below, so
        # they are evaluated once and reused.
        F1 = log_beta(g * alpha_p)
        F2 = log_beta(g * alpha_q)
        F3 = log_beta(sum_term)

        hd1 = 1/a * F1 + 1/b * F2 - F3
        a, b = b, b/(b-1)
        hd2 = 1/a * F1 + 1/b * F2 - F3
        return 1/2 * (hd1+hd2)

    def log_gamma(self, alpha):
        """Return the log of a standard Gamma sample with shape parameter ``alpha``.

        The sample is drawn in float32 and clamped from below at the smallest
        positive normal float32 so the logarithm stays finite.
        """
        sample = torch._standard_gamma(alpha.to(torch.float32))
        return torch.log(sample.clamp(MIN))

    def KL_gamma(self, alpha_p, alpha_q, beta_p=None, beta_q=None):
        """KL divergence from ``Gamma(alpha_p, beta_p)`` to ``Gamma(alpha_q, beta_q)``.

        ``beta_*`` are rate parameters.  When either is omitted the rate terms
        are skipped, which corresponds to equal rates.
        """
        KL = (alpha_p-alpha_q)*torch.digamma(alpha_p)-torch.lgamma(alpha_p)+torch.lgamma(alpha_q)
        if beta_p is not None and beta_q is not None:
            KL = KL + alpha_q*(torch.log(beta_p)-torch.log(beta_q))+alpha_p*(beta_q/beta_p-1.0)
        return KL

# Loss variant that processes the four spatial quadrants of the image separately.
@persistence.persistent_class
class GDD2Loss:
    """Diffusion training loss evaluated on the four spatial quadrants.

    A noisy observation is drawn in logit space as ``log u - log v`` with
    ``u ~ Gamma(eta * alpha * x0)`` and ``v ~ Gamma(eta - eta * alpha * x0)``
    for the top-left, top-right, bottom-left and bottom-right quadrants of
    ``x0``.  The quadrants are stitched back together, standardised and passed
    to ``net``; the network output is mapped back to ``[Shift, Shift + Scale]``
    and compared with ``x0`` using the divergence selected by ``lossType``.

    Supported ``lossType`` values: ``'KLUB'``, ``'HPD'``, ``'HPD1'``, ``'PHD'``.
    ``'KLUB'`` and ``'HPD'`` are averaged over the four quadrants;
    ``'HPD1'`` and ``'PHD'`` are evaluated once on the whole tensor.
    """

    # Loss types that compute_loss can evaluate.
    _LOSS_TYPES = ("KLUB", "HPD", "HPD1", "PHD")
    # Loss types that __call__ evaluates on the full tensor rather than per quadrant.
    _FULL_TENSOR_LOSS_TYPES = ("HPD1", "PHD")

    def __init__(self, eta=None, sigmoid_start=None, sigmoid_end=None, sigmoid_power=None, Scale=None, Shift=None, T=500, epsilon_t=1e-5,lossType='KLUB'):
        # Concentration used for quadrants 1..4 respectively; all four start
        # from the same value.
        self.eta = eta
        self.eta1 = eta
        self.eta2 = eta
        self.eta3 = eta
        # logit(alpha) = sigmoid_start + (sigmoid_end - sigmoid_start) * t**sigmoid_power
        self.sigmoid_start = sigmoid_start
        self.sigmoid_end = sigmoid_end
        self.sigmoid_power = sigmoid_power
        # Images are affinely mapped to [Shift, Shift + Scale] before use.
        self.Scale = Scale
        self.Shift = Shift
        self.T = T  # stored only; not read by any method of this class
        self.lossType = lossType
        # Lower end of the interval from which the time position is sampled.
        self.epsilon_t = epsilon_t
        self.min = torch.finfo(torch.float32).tiny
        self.eps = torch.finfo(torch.float32).eps

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Sample a noisy input, run ``net`` and return the unreduced loss.

        Args:
            net: callable invoked as
                ``net(x, logit_alpha, labels, augment_labels=...)``; its output
                is interpreted as the logit of the normalised clean image.
            images: 4-D image batch; values are mapped from ``[-1, 1]`` to
                ``[0, 1]`` (with clamping) and then to ``[Shift, Shift + Scale]``.
            labels: forwarded to ``net`` unchanged.
            augment_pipe: optional callable returning ``(images, augment_labels)``.

        Returns:
            The loss tensor; no reduction over the batch is applied here.
            For the per-quadrant loss types the four quadrant losses are
            added element-wise, which requires the quadrants to have equal
            shapes (even height and width).

        Raises:
            NotImplementedError: if ``lossType`` is not one of the supported
                values.  This is checked before any sampling or forward pass.
        """
        if self.lossType not in self._LOSS_TYPES:
            raise NotImplementedError("Loss type not implemented")

        batch_size, device = images.shape[0], images.device

        # Per-sample time position in (epsilon_t, 1] and a slightly earlier
        # position (95% of it) that defines the conditional step.
        rnd_uniform = torch.rand([batch_size, 1, 1, 1], device=device)
        rnd_position = 1 + rnd_uniform * (self.epsilon_t - 1)
        logit_alpha = self._logit_alpha(rnd_position)
        logit_alpha_previous = self._logit_alpha(rnd_position * 0.95)

        alpha = logit_alpha.sigmoid()
        alpha_previous = logit_alpha_previous.sigmoid()
        # The difference of two nearby sigmoids is taken in float64 to limit
        # cancellation error, then cast back.
        delta = (logit_alpha_previous.to(torch.float64).sigmoid()
                 - logit_alpha.to(torch.float64).sigmoid()).to(torch.float32)

        ones = torch.ones([batch_size, 1, 1, 1], device=device)
        quadrant_etas = (self.eta, self.eta1, self.eta2, self.eta3)

        # Prepare x0 in [Shift, Shift + Scale].
        x0, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
        x0 = ((x0 + 1.0) / 2.0).clamp(0, 1) * self.Scale + self.Shift

        _, _, H, W = x0.shape
        h2, w2 = H // 2, W // 2
        quadrants = (
            (slice(0, h2), slice(0, w2)),   # top-left
            (slice(0, h2), slice(w2, W)),   # top-right
            (slice(h2, H), slice(0, w2)),   # bottom-left
            (slice(h2, H), slice(w2, W)),   # bottom-right
        )

        # Draw the noisy logit quadrant by quadrant.  The order of the Gamma
        # draws (u then v, quadrant 1 to 4) is kept fixed so that seeded runs
        # reproduce the same samples.
        quadrant_logits = []
        for (rows, cols), eta_q in zip(quadrants, quadrant_etas):
            x0_q = x0[:, :, rows, cols]
            log_u = self.log_gamma(eta_q * alpha * x0_q)
            log_v = self.log_gamma(eta_q - eta_q * alpha * x0_q)
            quadrant_logits.append(log_u - log_v)

        # Stitch the quadrants back into a full-size tensor.
        logit_top = torch.cat(quadrant_logits[0:2], dim=3)
        logit_bottom = torch.cat(quadrant_logits[2:4], dim=3)
        logit_x_t = torch.cat([logit_top, logit_bottom], dim=2)

        E_logit_x_t, std_logit_x_t = self._logit_moments(alpha)

        logit_x0_hat = net((logit_x_t - E_logit_x_t) / std_logit_x_t, logit_alpha, labels, augment_labels=augment_labels)
        x0_hat = torch.sigmoid(logit_x0_hat) * self.Scale + self.Shift

        if self.lossType in self._FULL_TENSOR_LOSS_TYPES:
            # Evaluated once on the whole tensor instead of per quadrant.
            return self.compute_loss(x0, x0_hat, alpha, alpha_previous, ones * self.eta, delta)

        loss1, loss2, loss3, loss4 = (
            self.compute_loss(
                x0[:, :, rows, cols],
                x0_hat[:, :, rows, cols],
                alpha, alpha_previous, ones * eta_q, delta,
            )
            for (rows, cols), eta_q in zip(quadrants, quadrant_etas)
        )
        return (loss1 + loss2 + loss3 + loss4) / 4.0

    def _logit_alpha(self, position):
        """Return ``sigmoid_start + (sigmoid_end - sigmoid_start) * position**sigmoid_power``."""
        return self.sigmoid_start + (self.sigmoid_end - self.sigmoid_start) * (position ** self.sigmoid_power)

    def _logit_moments(self, alpha):
        """Return the mean and standard deviation used to standardise the noisy logit.

        E1/E2 and V1/V2 are the closed-form averages of the digamma and
        trigamma terms over ``x`` uniform on ``[Shift, Shift + Scale]``.
        V3/V4 add the variance of the digamma terms over that range, estimated
        on a 101-point grid with the trapezoid rule.  The covariance between
        the two digamma terms is not included.  ``eta1`` is used throughout.
        """
        eta = self.eta1
        xmin = self.Shift
        xmax = self.Shift + self.Scale
        inv_width = 1.0 / (eta * alpha * self.Scale)

        E1 = inv_width * ((eta * alpha * xmax).lgamma() - (eta * alpha * xmin).lgamma())
        E2 = inv_width * ((eta - eta * alpha * xmin).lgamma() - (eta - eta * alpha * xmax).lgamma())

        V1 = inv_width * ((eta * alpha * xmax).digamma() - (eta * alpha * xmin).digamma())
        V2 = inv_width * ((eta - eta * alpha * xmin).digamma() - (eta - eta * alpha * xmax).digamma())

        grids = (torch.arange(0, 101, device=alpha.device) / 100) * self.Scale + self.Shift
        alpha_x = alpha[:, :, 0, 0] * grids.unsqueeze(0)
        V3 = self._trapezoid_variance((eta * alpha_x).digamma(), E1)
        V4 = self._trapezoid_variance((eta - eta * alpha_x).digamma(), E2)

        return E1 - E2, (V1 + V2 + V3 + V4).sqrt()

    @staticmethod
    def _trapezoid_variance(values, mean):
        """Return ``clamp(mean(values**2) - mean**2, 0)`` with a trapezoid-rule mean.

        ``values`` has one row per sample and one column per grid point; the
        result is reshaped to ``(batch, 1, 1, 1)`` before ``mean**2`` is
        subtracted.
        """
        squared = values ** 2
        # Trapezoid rule: average the two end points into the first slot and
        # drop the last one, then take a plain mean over the rest.
        squared[:, 0] = (squared[:, 0] + squared[:, -1]) / 2
        mean_square = squared[:, :-1].mean(dim=1)
        return (mean_square.unsqueeze(1).unsqueeze(2).unsqueeze(3) - mean ** 2).clamp(0)

    def compute_loss(self, x0, x0_hat, alpha, alpha_previous, eta, delta):
        """Evaluate the divergence selected by ``lossType`` between ``x0`` and ``x0_hat``.

        Only the requested divergence is computed.

        Raises:
            NotImplementedError: if ``lossType`` is not supported.
        """
        loss_type = self.lossType
        if loss_type not in self._LOSS_TYPES:
            raise NotImplementedError("Loss type not implemented")

        # Parameters of the conditional step for the data (p) and the prediction (q).
        alpha_p = eta * delta * x0
        beta_p = eta - eta * alpha_previous * x0
        alpha_q = eta * delta * x0_hat
        beta_q = eta - eta * alpha_previous * x0_hat

        if loss_type == "HPD":
            return self.HPD(alpha_p, alpha_q, beta_p, beta_q)
        if loss_type == "HPD1":
            return self._decomposed(self.HPD1, alpha_p, beta_p, alpha_q, beta_q)
        if loss_type == "PHD":
            return self._decomposed(self.PHD, alpha_p, beta_p, alpha_q, beta_q)

        # 'KLUB': mix of the conditional term and the marginal term at level alpha.
        conditional = self._decomposed(self.KL_gamma, alpha_p, beta_p, alpha_q, beta_q)
        marginal = self._decomposed(
            self.KL_gamma,
            eta * alpha * x0,
            eta - eta * alpha * x0,
            eta * alpha * x0_hat,
            eta - eta * alpha * x0_hat,
        )
        return 0.97 * conditional + 0.03 * marginal

    @staticmethod
    def _decomposed(divergence, alpha_p, beta_p, alpha_q, beta_q):
        """Return ``D(alpha_q, alpha_p) + D(beta_q, beta_p) - D(sum_q, sum_p)``.

        Each of the three terms and the final result are clamped at zero.
        """
        return (divergence(alpha_q, alpha_p).clamp(0)
                + divergence(beta_q, beta_p).clamp(0)
                - divergence(alpha_q + beta_q, alpha_p + beta_p).clamp(0)).clamp(0)

    @staticmethod
    def _log_multivariate_beta(x):
        """Return ``sum(lgamma(x)) - lgamma(sum(x))`` over dim 1, keeping that dim."""
        return torch.sum(torch.lgamma(x), dim=1, keepdim=True) - torch.lgamma(torch.sum(x, dim=1, keepdim=True))

    def HPD(self,alpha_p,alpha_q,beta_p,beta_q):
        """Element-wise divergence built from log-beta-function terms.

        Uses exponents ``a = 2`` and ``b = a / (a - 1)``; every log-beta term
        is clamped at zero before being combined.
        """
        a = 2.0
        b = a/(a-1)
        number1=1.0
        F1=(torch.lgamma((a * alpha_p)) + torch.lgamma((a * beta_p))-torch.lgamma(a*(alpha_p+beta_p))).clamp(0)
        F2=(torch.lgamma((b * alpha_q)) + torch.lgamma((b * beta_q))-torch.lgamma(b*(alpha_q+beta_q))).clamp(0)
        F3=(torch.lgamma((alpha_p+alpha_q-number1))+torch.lgamma((beta_p+beta_q-number1) ) -torch.lgamma((alpha_p+alpha_q-number1+beta_p+beta_q-number1))).clamp(0)
        hpd=1/a *F1+ 1/b *F2-F3
        return -hpd

    def HPD1(self,alpha_p,alpha_q):
        """Symmetrised divergence over dim 1 using conjugate exponents ``(a, b)``.

        The term ``B(a*p)/a + B(b*q)/b - B(p + q)`` (``B`` being the log
        multivariate beta function over dim 1) is evaluated for ``(a, b)`` and
        for the swapped pair, and the two results are averaged.
        """
        log_beta = self._log_multivariate_beta

        def term(a, b):
            return 1/a * log_beta(a * alpha_p) + 1/b * log_beta(b * alpha_q) - log_beta(alpha_p + alpha_q)

        a = 2.0
        b = a/(a-1)
        hd1 = term(a, b)
        a, b = b, b/(b-1)
        hd2 = term(a, b)
        return 1/2 * (hd1+hd2)

    def HPD2(self,alpha_p,alpha_q):
        """Same as :meth:`HPD1` but with 1 added to every argument of ``B``.

        The offset is a scalar, so inputs of any width are accepted.  The
        previous ``ones((1, 32))`` tensor only broadcast against a last
        dimension of 32 (or 1).
        """
        log_beta = self._log_multivariate_beta
        offset = 1.0

        def term(a, b):
            return (1/a * log_beta(a * alpha_p + offset)
                    + 1/b * log_beta(b * alpha_q + offset)
                    - log_beta(alpha_p + alpha_q + offset))

        a = 2.0
        b = a/(a-1)
        hd1 = term(a, b)
        a, b = b, b/(b-1)
        hd2 = term(a, b)
        return 1/2 * (hd1+hd2)

    def PHD(self,alpha_p,alpha_q):
        """Symmetrised divergence over dim 1 with a mixed third term.

        Computes ``B(g*p)/a + B(g*q)/b - B(g/a * p + g/b * q)`` with ``g = 1``
        for ``(a, b)`` and for the swapped pair, then averages.  The mixed
        argument of the third term is built once from the initial ``(a, b)``.
        """
        log_beta = self._log_multivariate_beta
        a = 2.0
        b = a/(a-1)
        g = 1.0

        sum_term = (g / a) * alpha_p + (g / b) * alpha_q

        # None of the three B terms depends on the exponent swap below, so
        # they are evaluated once and reused.
        F1 = log_beta(g * alpha_p)
        F2 = log_beta(g * alpha_q)
        F3 = log_beta(sum_term)

        hd1 = 1/a * F1 + 1/b * F2 - F3
        a, b = b, b/(b-1)
        hd2 = 1/a * F1 + 1/b * F2 - F3
        return 1/2 * (hd1+hd2)

    def log_gamma(self, alpha):
        """Return the log of a standard Gamma sample with shape parameter ``alpha``.

        The sample is drawn in float32 and clamped from below at the smallest
        positive normal float32 so the logarithm stays finite.
        """
        sample = torch._standard_gamma(alpha.to(torch.float32))
        return torch.log(sample.clamp(MIN))

    def KL_gamma(self, alpha_p, alpha_q, beta_p=None, beta_q=None):
        """KL divergence from ``Gamma(alpha_p, beta_p)`` to ``Gamma(alpha_q, beta_q)``.

        ``beta_*`` are rate parameters.  When either is omitted the rate terms
        are skipped, which corresponds to equal rates.
        """
        KL = (alpha_p-alpha_q)*torch.digamma(alpha_p)-torch.lgamma(alpha_p)+torch.lgamma(alpha_q)
        if beta_p is not None and beta_q is not None:
            KL = KL + alpha_q*(torch.log(beta_p)-torch.log(beta_q))+alpha_p*(beta_q/beta_p-1.0)
        return KL
