import torch
import torch.nn as nn
class Hybrid_Loss(nn.Module):
    """Combine three losses with weights that are re-normalized on every call.

    The returned value is::

        w1 * loss_G + w2 * (alpha * loss_S + beta * loss_T)

    Before combining, each weight pair is updated multiplicatively and
    renormalized to sum to 1. For the outer pair ``(w1, w2)`` the rule is
    ``w_i <- w_i * exp(lambda1 * L_i) / sum_j(w_j * exp(lambda1 * L_j))``,
    where ``L`` is ``loss_G`` and ``loss_C = alpha * loss_S + beta * loss_T``.
    The inner pair ``(alpha, beta)`` uses ``lambda2`` with ``loss_S`` and
    ``loss_T``. With a positive lambda, the term with the larger loss gains
    weight.

    The weights are module state. After the first ``forward`` call they are
    tensors. They are computed under ``torch.no_grad()``, so no gradient
    flows through them.

    Args:
        w1_pre: initial weight of ``loss_G``.
        w2_pre: initial weight of the composite ``alpha*loss_S + beta*loss_T``.
        alpha: initial weight of ``loss_S`` inside the composite term.
        beta: initial weight of ``loss_T`` inside the composite term.
        lambda1: update step (temperature) for the ``(w1, w2)`` pair.
        lambda2: update step (temperature) for the ``(alpha, beta)`` pair.
    """

    def __init__(self, w1_pre, w2_pre, alpha, beta, lambda1=0.9, lambda2=0.1):
        super(Hybrid_Loss, self).__init__()
        self.w1_pre = w1_pre
        self.w2_pre = w2_pre
        self.alpha = alpha
        self.beta = beta
        self.lambda1 = lambda1
        self.lambda2 = lambda2

    @staticmethod
    def _reweight(w_a, w_b, loss_a, loss_b, lam):
        """Return the pair ``(w_a*e_a, w_b*e_b) / (w_a*e_a + w_b*e_b)``, where ``e_x = exp(lam*loss_x)``.

        The larger exponent is subtracted before ``exp``. This stops large
        losses from overflowing to ``inf``, which would turn the ratio into
        ``nan``. The ratio itself is unchanged because the common factor
        cancels. Expects ``loss_a`` and ``loss_b`` to be tensors, as
        ``torch.exp`` already required in the original code.
        """
        z_a = lam * loss_a
        z_b = lam * loss_b
        shift = torch.maximum(z_a, z_b)
        s_a = w_a * torch.exp(z_a - shift)
        s_b = w_b * torch.exp(z_b - shift)
        total = s_a + s_b
        return s_a / total, s_b / total

    def forward(self, loss_G, loss_S, loss_T):
        """Update the stored weights from these losses, then return their weighted sum.

        Args:
            loss_G, loss_S, loss_T: loss tensors (presumably scalars; TODO confirm
                with callers). Gradients flow back into them through the
                returned value.

        Returns:
            ``w1 * loss_G + w2 * (alpha * loss_S + beta * loss_T)``, computed
            with the freshly updated weights.
        """
        # The composite term uses the alpha/beta from *before* this step's
        # update. It only drives the outer (w1, w2) update.
        loss_C = self.alpha * loss_S + self.beta * loss_T
        with torch.no_grad():
            self.w1_pre, self.w2_pre = self._reweight(
                self.w1_pre, self.w2_pre, loss_G, loss_C, self.lambda1
            )
            self.alpha, self.beta = self._reweight(
                self.alpha, self.beta, loss_S, loss_T, self.lambda2
            )

        return self.w1_pre * loss_G + self.w2_pre * (self.alpha * loss_S + self.beta * loss_T)