             
# Fallback values for every recognised configuration option. Options are
# grouped by the part of the training setup they belong to.
DEFAULTS = {
    # ----- General training options
    'dataset': 'CIFAR10',
    'batch_size': 64,
    'seed': None,  # Seed for the RNG
    'max_samples': None,  # None means the full dataset
    'optim': 'SGD',
    'loss': 'CrossEntropy',
    # ----- Model and its sparse initialization
    'model': 'ResNet18',
    'model_init_sparsity': 0.9,
    'conv_group': True,
    'r': [10, 10, 10],  # Rescaling after masking
    # ----- Global optimizer parameters
    'num_epochs': 100,
    'learning_rate': 0.1,
    'lr_scheduler': 'CosineAnnealing',
    'momentum': 0.9,
    # ----- Proximal step parameters
    'reg': 'None',
    'delta': 1.0,
    'lambda1': 0.001,
    'lambda0': 0.001,
    # ----- Parameters for sparse learning
    'full_update_frequency': 100,
    'full_update_duration': 1,
    # Whether the frequency refers to a training step or a full epoch
    'full_update_mode': 'step',
    'eps': 1e-3,
    'kappa': 0.5,
}

class Config():
    """
    Training configuration built from keyword arguments and DEFAULTS.

    Every selected option becomes an instance attribute. Options that are
    listed in DEFAULTS but are not selected for the chosen optimizer are not
    set, even when passed explicitly. Keyword arguments that do not appear
    in DEFAULTS are stored as additional attributes.
    """

    def __init__(self,verbose=True, **kwargs):
        """
        Declare configuration for the training. Use values from DEFAULTS when
        unspecified. Options not relevant to the selected optimizer will not
        be set.

        Parameters
        ----------
        verbose : bool
            If True, print a notice for every attribute set from a keyword
            that is not listed in DEFAULTS.
        **kwargs
            Option values overriding DEFAULTS, plus any extra attributes.
        """
        self._verbose = verbose # Property will be neither displayed nor saved

        # Order matters: _set_optimizer reads self.optim, which is assigned
        # by _set_basic.
        self._set_basic(**kwargs)
        self._set_optimizer(**kwargs)

        self._populate_remaining(**kwargs)

    def _populate(self, options_to_set, **kwargs):
        """
        Set each option in `options_to_set` as an attribute, taking the value
        from `kwargs` when present and from DEFAULTS otherwise.

        Raises KeyError if an option is not listed in DEFAULTS.
        """
        for option in options_to_set:
            # `get` rather than `pop`: kwargs is a per-call copy, so removing
            # entries has no effect on the caller, and `get` keeps the user's
            # value even if an option were listed more than once.
            setattr(self, option, kwargs.get(option, DEFAULTS[option]))

    def _set_basic(self, **kwargs):
        """Set the options that are stored regardless of the optimizer."""
        options_to_set = ['dataset', 'batch_size', 'seed', 'max_samples',
                          'model', 'model_init_sparsity', 'conv_group', 'r',
                          'loss', 'optim', 'num_epochs',
                          'learning_rate', 'lr_scheduler']
        self._populate(options_to_set, **kwargs)

    def _set_optimizer(self, **kwargs):
        """
        Set the options selected by the optimizer name in `self.optim`.

        An optimizer name that matches none of the groups below results in
        no optimizer-specific attributes being set.
        """
        options_to_set = []
        # Methods with momentum
        if self.optim in ['SGD', 'LinBreg', 'LinBregSparse']:
            options_to_set += ['momentum']

        # Methods that perform a proximal step
        if self.optim in ['GradSkip', 'LinBregSparse', 'ProxSGD', 'AdaBreg',
                          'AdaBregSparse', 'LinBreg', 'LinBregSparseML',
                          'debug']:
            options_to_set += ['reg', 'delta', 'lambda1', 'lambda0']

        # Methods with frequency-based sparse update criteria
        if self.optim in ['GradSkip', 'LinBregSparse', 'SGD-sparse',
                          'AdaBregSparse', 'debug']:
            options_to_set += ['full_update_frequency', 'full_update_duration',
                               'full_update_mode']

        # Methods with inequality-based sparse update criteria
        if self.optim in ['LinBregSparseML']:
            options_to_set += ['eps', 'kappa']

        self._populate(options_to_set, **kwargs)

    def _populate_remaining(self, **kwargs):
        """
        Store every keyword argument whose name is not a key of DEFAULTS.

        Keys that are in DEFAULTS are skipped here; they were either set by
        the earlier steps or deliberately left unset.
        """
        for key, value in kwargs.items():
            if key in DEFAULTS:
                continue
            if self._verbose:
                print(f' (!) Setting a non-default attribute "{key}"')
            setattr(self, key, value)

    def __repr__(self):
        """Return a multi-line listing of all attributes except `_verbose`."""
        lines = ['Configurations with attributes:']
        lines.extend(f'{name} : {value}'
                     for name, value in vars(self).items()
                     if name != '_verbose')
        return '\n'.join(lines)

    def save(self, name='configurations'):
        """
        Save as an easily readable txt file. Note: Not in a format to recreate
        configurations

        Parameters
        ----------
        name : str
            Output path without extension; ".txt" is appended.
        """
        # Explicit encoding so the file content does not depend on the
        # platform's locale and non-ASCII values cannot fail to encode.
        with open(name+".txt", "w", encoding="utf-8") as f:
            f.writelines(f"{k}: {repr(v)}\n"
                         for k, v in vars(self).items()
                         if k != '_verbose')


