import torch
from utils.normal_kl_divergence import kl_divergence



class ADDMNIST_DPL(torch.nn.Module):
    """Wrap a base loss and optionally add a ``beta``-weighted KL-divergence term.

    Args:
        loss: callable invoked as ``loss(out_dict, args)`` that returns a pair
            ``(loss, losses)``, where ``losses`` is a dict of loss components.
        nr_classes: stored on the instance; not read by this class.
        pcbm: if True, ``forward`` adds ``beta * KL`` to the base loss. The KL term
            is computed with ``kl_divergence`` from ``out_dict['MUS']`` and
            ``out_dict['LOGVARS']``.
        beta: weight of the KL term. Defaults to 0.1, the value previously
            hard-coded here.
    """

    def __init__(self, loss, nr_classes=19, pcbm=False, beta=0.1) -> None:
        super().__init__()
        self.base_loss = loss
        self.nr_classes = nr_classes
        self.pcbm = pcbm
        self.beta = beta

    def forward(self, out_dict, args):
        """Compute the base loss, plus the weighted KL term when ``pcbm`` is set.

        Args:
            out_dict: model outputs passed to ``base_loss``. When ``pcbm`` is set,
                it must also provide the ``'MUS'`` and ``'LOGVARS'`` entries.
            args: forwarded unchanged to ``base_loss``.

        Returns:
            ``(loss, losses)``. When ``pcbm`` is set, ``loss`` includes
            ``beta * kl_div`` and ``losses`` gains a ``'kl-div'`` entry.
        """
        loss, losses = self.base_loss(out_dict, args)

        if self.pcbm:
            mus = out_dict['MUS']
            logvars = out_dict['LOGVARS']
            # NOTE(review): only the first two (mu, logvar) pairs are used —
            # presumably one per input image; confirm against the model.
            kl_div = sum(kl_divergence(mus[i], logvars[i]) for i in range(2))

            # Out-of-place add on purpose: `loss += ...` would mutate the tensor
            # returned by base_loss in place. That silently changes any other
            # reference to it (e.g. an entry in `losses`) and can trip
            # autograd's in-place version checks during backward.
            loss = loss + self.beta * kl_div
            # NOTE(review): this stores the graph-attached KL value, not a
            # detached/float copy; confirm callers expect that.
            losses.update({'kl-div': kl_div})

        return loss, losses
    
class KAND_DPL(torch.nn.Module):
    """Loss wrapper that delegates entirely to the wrapped ``base_loss``.

    Args:
        loss: callable invoked as ``loss(out_dict, args)`` that returns a pair
            ``(loss, losses)``.
        nr_classes: stored on the instance; not read by this class.
    """

    def __init__(self, loss, nr_classes=2) -> None:
        super().__init__()
        self.nr_classes = nr_classes
        self.base_loss = loss

    def forward(self, out_dict, args):
        """Return ``(loss, losses)`` exactly as produced by ``base_loss``."""
        # Unpack explicitly so that a result that is not a pair fails here,
        # and the caller always receives a 2-tuple.
        total, breakdown = self.base_loss(out_dict, args)
        return total, breakdown