class BaseConfig:
    """Fields shared by every configuration class in this file.

    Copies ``dataset``, ``dense_dim``, ``latent_dim`` and ``num_sampling``
    from ``args`` and records the ``hidden_states`` flag.

    Args:
        args: Object exposing the attributes listed above (presumably a
            parsed command-line namespace -- confirm against the caller).
        hidden_states: Flag stored unchanged on the instance; defaults to
            ``True``.

    Raises:
        AttributeError: If ``args`` lacks one of the copied attributes.
    """

    def __init__(self, args, hidden_states=True):
        # NOTE: ``args.mask`` is not copied; that assignment is disabled.
        for field in ("dataset", "dense_dim", "latent_dim"):
            setattr(self, field, getattr(args, field))
        self.hidden_states = hidden_states
        self.num_sampling = args.num_sampling


class BetaVAEConfig(BaseConfig):
    """Configuration for the Beta-VAE variant (going by the class name).

    Adds ``in_channel`` plus the ``alpha``, ``beta`` and ``lamb`` values
    read from ``args`` on top of the shared ``BaseConfig`` fields.

    Args:
        args: Object exposing the ``BaseConfig`` attributes as well as
            ``alpha``, ``beta`` and ``lamb``.
        in_channel: Stored unchanged; defaults to ``1`` (presumably the
            number of input channels -- confirm against the model code).
    """

    def __init__(self, args, in_channel=1):
        # ``hidden_states`` is not forwarded, so the base default applies.
        super().__init__(args)
        self.in_channel = in_channel
        # Copy the remaining fields from ``args`` in declaration order.
        for name in ("alpha", "beta", "lamb"):
            setattr(self, name, getattr(args, name))


class BetaTCVAEConfig(BaseConfig):
    """Configuration for the Beta-TCVAE variant (going by the class name).

    Adds ``in_channel`` and ``dataset_size`` on top of the shared
    ``BaseConfig`` fields. Both come from the constructor arguments,
    not from ``args``.

    Args:
        args: Object exposing the attributes ``BaseConfig`` copies.
        in_channel: Stored unchanged; defaults to ``1`` (presumably the
            number of input channels -- confirm against the model code).
        dataset_size: Stored unchanged; defaults to ``0``.
    """

    def __init__(self, args, in_channel=1, dataset_size=0):
        # ``hidden_states`` is not forwarded, so the base default applies.
        super().__init__(args)
        self.in_channel = in_channel
        self.dataset_size = dataset_size


class FactorVAEConfig(BaseConfig):
    """Configuration for the FactorVAE variant (going by the class name).

    Adds ``in_channel`` and the ``gamma`` value read from ``args`` on top
    of the shared ``BaseConfig`` fields.

    Args:
        args: Object exposing the ``BaseConfig`` attributes and ``gamma``.
        in_channel: Stored unchanged; defaults to ``1`` (presumably the
            number of input channels -- confirm against the model code).
    """

    def __init__(self, args, in_channel=1):
        # ``hidden_states`` is not forwarded, so the base default applies.
        super().__init__(args)
        self.in_channel = in_channel
        self.gamma = args.gamma


class CLGVAEConfig(BaseConfig):
    """Configuration for the CLG-VAE variant (going by the class name).

    Adds ``in_channel`` and several values read from ``args`` on top of
    the shared ``BaseConfig`` fields.

    Args:
        args: Object exposing the ``BaseConfig`` attributes as well as
            ``subspace_sizes_ls``, ``subgroup_sizes_ls``, ``no_exp``,
            ``hy_hes``, ``hy_rec``, ``hy_commute`` and ``forward_eq_prob``.
        in_channel: Stored unchanged; defaults to ``1`` (presumably the
            number of input channels -- confirm against the model code).

    Raises:
        AttributeError: If ``args`` lacks one of the attributes above.
    """

    def __init__(self, args, in_channel=1):
        # ``hidden_states`` is not forwarded, so the base default applies.
        super().__init__(args)
        # This argument used to be accepted and then dropped; keep it on
        # the instance, as the other config classes in this file do.
        self.in_channel = in_channel
        self.subspace_sizes_ls = args.subspace_sizes_ls  # list of int
        self.subgroup_sizes_ls = args.subgroup_sizes_ls  # list of int
        self.no_exp = args.no_exp
        self.hy_hes = args.hy_hes
        self.hy_rec = args.hy_rec
        self.hy_commute = args.hy_commute
        self.forward_eq_prob = args.forward_eq_prob