from dl.src.info.base import BaseInfo
from dl.src.info.betavae import BetaVAEInfo

class CMCS_GT_Info(BaseInfo):
    """Info record for CMCS models carrying loss terms and evaluation metrics.

    The hyper-parameters ``alpha``, ``beta``, ``gamma`` and ``nth_root`` are
    read from ``args``. Every other field is looked up in ``kwargs``, so a
    missing key raises ``KeyError``. Both ``args`` and all ``kwargs`` are also
    forwarded unchanged to ``BaseInfo.__init__``.
    """

    def __init__(self, args, **kwargs):
        super(CMCS_GT_Info, self).__init__(args, **kwargs)

        # Hyper-parameters taken straight from the parsed arguments.
        self.alpha, self.beta, self.gamma, self.nth_root = (
            args.alpha,
            args.beta,
            args.gamma,
            args.nth_root,
        )

        # (attribute name, kwargs key) pairs, listed in assignment order.
        # Most keys match their attribute; two metric entries are renamed.
        kwarg_fields = (
            # Loss / objective terms (per their names).
            ("elbo", "elbo"),
            ("reconst", "reconst"),
            ("kld", "kld"),
            ("code_loss", "code_loss"),
            ("canonical_loss", "canonical_loss"),
            ("decoder_equiv", "decoder_equiv"),
            # Metric scores. The names suggest the BetaVAE, FactorVAE, MIG,
            # SAP and DCI disentanglement metrics (NOTE(review): confirm).
            ("beta_vae", "beta_vae"),
            ("fvm", "factor_disent"),
            ("mig", "mig"),
            ("sap", "sap"),
            ("dci_disent", "dci_disent"),
            # The attribute spelling "completness" is kept as-is because
            # callers may depend on it.
            ("dci_completness", "dci_comple"),
        )
        for attr, key in kwarg_fields:
            setattr(self, attr, kwargs[key])


class CMCS_Super_Info(CMCS_GT_Info):
    """``CMCS_GT_Info`` extended with label-prediction statistics.

    Requires ``kwargs`` to also contain ``label_loss`` and ``label_acc``.
    A missing key raises ``KeyError``.
    """

    def __init__(self, args, **kwargs):
        super(CMCS_Super_Info, self).__init__(args, **kwargs)
        # Label statistics, assigned in the order loss -> accuracy.
        self.label_loss, self.label_acc = kwargs["label_loss"], kwargs["label_acc"]

class CMCS_SemiSuper_Info(CMCS_GT_Info):
    """``CMCS_GT_Info`` extended with label loss, label KL term and accuracy.

    Requires ``kwargs`` to also contain ``label_loss``, ``label_kld`` and
    ``label_acc``. A missing key raises ``KeyError``.
    """

    def __init__(self, args, **kwargs):
        super(CMCS_SemiSuper_Info, self).__init__(args, **kwargs)
        # Each attribute shares its name with the kwargs key it is read from.
        self.label_loss, self.label_kld, self.label_acc = (
            kwargs[name] for name in ("label_loss", "label_kld", "label_acc")
        )

class CMCS_UnSuper_Info(BetaVAEInfo):
    """``BetaVAEInfo`` extended with CMCS-specific fields.

    Reads ``gamma`` and ``prior_list`` from ``args`` and ``abs_diff`` from
    ``kwargs``. A missing ``abs_diff`` key raises ``KeyError``.
    """

    def __init__(self, args, **kwargs):
        super(CMCS_UnSuper_Info, self).__init__(args, **kwargs)
        # Values are evaluated and assigned in the order gamma,
        # abs_diff, prior_list.
        self.gamma, self.abs_diff, self.prior_list = (
            args.gamma,
            kwargs["abs_diff"],
            args.prior_list,
        )






