import os
import csv




class BaseInfo:
    """Snapshot of a run's hyperparameters that can be appended to a CSV file.

    Every instance attribute becomes one CSV column, in attribute-assignment
    order. That includes ``file_dir`` itself. A subclass that sets extra
    instance attributes therefore gets extra columns automatically.
    """

    def __init__(self, args, **kwargs):
        """Copy the relevant hyperparameters off ``args``.

        Args:
            args: Parsed run configuration. It must expose ``results_file``,
                ``optimizer``, ``num_epoch``, ``lr_rate``, ``seed``,
                ``weight_decay``, ``train_batch_size`` and ``latent_dim``.
            **kwargs: Accepted for subclass/caller compatibility; currently
                ignored by this base class.
        """
        self.file_dir = args.results_file  # path of the CSV that rows are appended to
        self.opt = args.optimizer
        self.epoch = args.num_epoch
        self.lr = args.lr_rate
        self.seed = args.seed
        self.wd = args.weight_decay
        self.batch = args.train_batch_size
        self.latent = args.latent_dim

    def write_results(self):
        """Append this instance's attributes as one row to ``self.file_dir``.

        The header row is written only when the file is empty. That covers a
        file that did not exist yet, and also one that exists with zero bytes.

        NOTE(review): an existing header is not compared with the current
        columns. Appending rows with a different attribute set misaligns them.
        """
        row = vars(self)
        # newline="" is required by the csv module to avoid blank lines on Windows.
        with open(self.file_dir, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(row))
            # In append mode the stream starts at end-of-file, so position 0
            # means the file is empty. Checking here instead of calling
            # os.path.isfile() beforehand avoids a check-then-open race. It
            # also handles pre-existing empty files.
            if f.tell() == 0:
                writer.writeheader()
            writer.writerow(row)