class JCGConfig:
    """Flat configuration object built from a parsed-arguments namespace.

    Each setting is copied verbatim from the matching attribute of ``args``.
    The only exception is ``hidden_states``, which comes from the constructor
    argument. No validation or conversion is performed.
    """

    def __init__(self, args, hidden_states=True):
        """Copy the configuration attributes from ``args`` onto ``self``.

        Args:
            args: Any object exposing the attributes listed below, presumably
                an ``argparse.Namespace`` (TODO confirm against the caller).
            hidden_states: Stored as-is on ``self.hidden_states``.

        Raises:
            AttributeError: If ``args`` lacks one of the required attributes.
        """
        # Ordered so that attributes are set (and looked up on ``args``) in
        # the same sequence as before. ``hidden_states`` is taken from the
        # constructor argument rather than from ``args``.
        field_order = (
            "dataset",
            "dense_dim",
            "latent_dim",
            "hidden_states",
            "num_sampling",
            "c_rot",
            "g_rot",
            "n_flip",
            "temperature",
            "normalization",
            "soft",
            "beta",
        )
        for field_name in field_order:
            if field_name == "hidden_states":
                field_value = hidden_states
            else:
                field_value = getattr(args, field_name)
            setattr(self, field_name, field_value)