from src.info.base import BaseInfo


class BetaVAEInfo(BaseInfo):
    """Run record holding the ``beta`` setting and the reported result values.

    Extends ``BaseInfo`` (defined outside this file) by copying ``args.beta``
    and a fixed set of entries from the keyword arguments onto the instance.
    """

    def __init__(self, args, **kwargs):
        """Initialise the base record, then attach beta and the result values.

        Args:
            args: Object exposing a ``beta`` attribute; it is also forwarded
                unchanged to ``BaseInfo.__init__``.
            **kwargs: Result values. Must contain the keys ``elbo``, ``obj``,
                ``reconst``, ``kld``, ``beta_vae``, ``factor_disent``,
                ``mig``, ``sap``, ``dci_disent`` and ``dci_comple``. All
                keyword arguments are also forwarded to ``BaseInfo.__init__``.

        Raises:
            KeyError: If any of the required keys is missing from ``kwargs``.
        """
        super().__init__(args, **kwargs)
        self.beta = args.beta

        # (attribute name on self, key looked up in kwargs), in the order the
        # attributes are created. Two attributes are named differently from
        # their source key. The spelling ``dci_completness`` is kept as-is
        # because code outside this file may read the attribute by that name.
        # NOTE(review): the last six look like disentanglement scores --
        # confirm against the caller that builds the results dict.
        field_sources = (
            ("elbo", "elbo"),
            ("obj", "obj"),
            ("reconst", "reconst"),
            ("kld", "kld"),
            ("beta_vae", "beta_vae"),
            ("fvm", "factor_disent"),
            ("mig", "mig"),
            ("sap", "sap"),
            ("dci_disent", "dci_disent"),
            ("dci_completness", "dci_comple"),
        )
        for attr_name, result_key in field_sources:
            setattr(self, attr_name, kwargs[result_key])


class CGE_BetaVAEInfo(BetaVAEInfo):
    """``BetaVAEInfo`` extended with six additional options read from ``args``."""

    def __init__(self, args, **kwargs):
        """Initialise the parent record, then copy the extra options from ``args``.

        Args:
            args: Object exposing ``c_rot``, ``g_rot``, ``n_flip``,
                ``temperature``, ``normalization`` and ``soft`` in addition to
                whatever ``BetaVAEInfo.__init__`` reads from it.
            **kwargs: Forwarded unchanged to ``BetaVAEInfo.__init__``.

        Raises:
            AttributeError: If ``args`` lacks one of the options listed above.
        """
        super().__init__(args, **kwargs)

        # Each option is stored on the instance under the same name it has
        # on ``args``; the order matches the order the attributes are created.
        option_names = (
            "c_rot",
            "g_rot",
            "n_flip",
            "temperature",
            "normalization",
            "soft",
        )
        for option in option_names:
            setattr(self, option, getattr(args, option))

def write_info(args, results):
    """Build a ``CGE_BetaVAEInfo`` from the given inputs and write it out.

    Args:
        args: Options object passed through to ``CGE_BetaVAEInfo``.
        results: Mapping unpacked as keyword arguments for
            ``CGE_BetaVAEInfo``; it must provide every key that constructor
            chain looks up.

    Returns:
        None. The only effect is the ``write_results()`` call on the new
        object; that method is not defined in this file.
    """
    CGE_BetaVAEInfo(args, **results).write_results()