from tensorflow import keras as keras
from keras import optimizers, losses, metrics
from ml_collections import config_dict
import datetime
from mask_models import *

#* Configuration
# Root experiment configuration. Each section below is attached to ``config``
# first and then populated through a short-lived private alias (``_train``,
# ``_data``, ...) to keep the assignments short. The aliases are deleted at the
# end so that ``config`` is the only name this section leaves in the module.
config = config_dict.ConfigDict()
config.seed = 1  #! Not used currently
config.experiment = 0  #! Not used currently

#* Training Configuration
config.train = config_dict.ConfigDict()
_train = config.train
_train.num_runs = 1
# Stored uncalled (class/factory, not an instance); presumably provided by
# `from mask_models import *` — confirm.
_train.model = LeNet5Masked
_train.loss_fn = losses.SparseCategoricalCrossentropy  # stored uncalled
_train.use_reg = True
_train.weight_decay = 1e-4
_train.batchwise_update = True
_train.metric = config_dict.ConfigDict()
_train.metric.loss = metrics.SparseCategoricalCrossentropy
_train.metric.acc = metrics.SparseCategoricalAccuracy
_train.compute_divs = False

#* Data Configuration
config.data = config_dict.ConfigDict()
_data = config.data
_data.name = 'mnist'
_data.alloc_type = 'uniform'
_data.alloc_ratio = 4
_data.beta = 0.5
_data.batch_size = 128
_data.num_validation = 5
_data.shape = ()  #! TBD in Code
_data.num_classes = 0  #! TBD in Code

#* Beta Configuration
config.beta = config_dict.ConfigDict()
_beta = config.beta
_beta.mode = 'value'
_beta.value = 0.5
_beta.index = 0

#* Server Configuration
config.server = config_dict.ConfigDict()
_server = config.server
_server.num_epochs = 200
_server.update_thresh = 0.75
# Early-stopping settings.
_server.es = config_dict.ConfigDict()
_es = _server.es
_es.enable = False
_es.threshold = 0.1  #! Not used currently
_es.wait = 0
_es.patience = 3

#* Worker Configuration
config.worker = config_dict.ConfigDict()
_worker = config.worker
_worker.num = 10
_worker.inact_prob = 0

# Local (per-worker) epoch settings.
_worker.epoch = config_dict.ConfigDict()
_epoch = _worker.epoch
_epoch.type = 'constant'
_epoch.mean = 3
_epoch.std = 2
_epoch.beta = 0.5
_epoch.coef = 20
_epoch.is_random = False

# Local optimizer settings; ``fn`` is the optimizer class, stored uncalled.
_worker.tx = config_dict.ConfigDict()
_tx = _worker.tx
_tx.fn = optimizers.Adam
_tx.lr = 0.1
_tx.moment = 0.9
_tx.lr_decay_per_server_epoch = 1.0
_tx.type_decay_per_server_epoch = 'Geometric'
_tx.lr_decay_per_worker_epoch = 1.0
_tx.type_decay_per_worker_epoch = 'Geometric'


#* Compressor Configuration
config.compressor = config_dict.ConfigDict()
_comp = config.compressor
_comp.enable = True
_comp.reset_aggregation = True
_comp.use_indiv_reference = False
_comp.use_posterior_prior = False
_comp.uplink_samples = 1
_comp.downlink_samples = 10
_comp.common_dl_prior = True
_comp.reuse_samples = True
_comp.project_kl_divergences = None  # None disables projection; otherwise the KL divergence in bits
_comp.project_block_kl_divergences = None  # None disables projection; otherwise the KL divergence in bits
_comp.adaptive_blocks_ul = False
_comp.adaptive_blocks_dl = False
_comp.adaptive_avg = False
_comp.avg_dev_factor = 2
_comp.kl_rate_ul = 8.0
_comp.kl_rate_dl = 8.0
_comp.split_dl = False
_comp.max_block_size = 512
_comp.block_size = 256
_comp.sample_size = 256

#* Weights & Biases Configuration
config.wandb = config_dict.ConfigDict()
_wandb = config.wandb
_wandb.name = ''  #! TBD in Code
_wandb.project = 'bicompfl'
_wandb.job_type = 'bicompfl'

# Drop the section aliases so only ``config`` remains in the module namespace.
del _train, _data, _beta, _server, _es, _worker, _epoch, _tx, _comp, _wandb


# Flat, upper-case-keyed snapshot of the main hyperparameters from ``config``,
# grouped under a single 'PARAMETERS' key. The values are read once, when this
# module executes. Later reassignments of fields in ``config`` do not update
# this dict. Presumably passed to W&B as the run config — confirm at the
# call site.
wandb_config = {
    'PARAMETERS': {
        'RANDOM_SEED': config.seed,
        'EXPERIMENT': config.experiment,

        'TRAINING_DATASET': config.data.name,
        'ALLOCATION_TYPE': config.data.alloc_type,
        'ALLOCATION_RATIO': config.data.alloc_ratio,
        'BATCH_LENGTH': config.data.batch_size,
        'DATASET_BETA': config.data.beta,
        'VALIDATION_SAMPLE_NUMBER_PER_CLASS': config.data.num_validation,

        'BETA_MODE': config.beta.mode,
        'BETA_INDEX': config.beta.index,
        'BETA': config.beta.value,

        'SERVER_MAX_ITERATION': config.server.num_epochs,
        'CLIENT_UPDATE_THRESHOLD_RATIO': config.server.update_thresh,
        # Log the on/off flag rather than the whole early-stopping ConfigDict,
        # consistent with the scalar EARLY_STOPPING_RATE below.
        'EARLY_STOPPING': config.server.es.enable,
        'EARLY_STOPPING_RATE': config.server.es.threshold,

        'CLIENT_NUMBER': config.worker.num,
        'INACTIVE_PROBABILITY': config.worker.inact_prob,
        'UPDATE_MODEL_EVERY_BATCH': config.train.batchwise_update,
        # Boolean flag, presumably "local epoch count is randomised". Previously
        # this duplicated LOCAL_ITERATION_TYPE, and is_random was never logged.
        'VARIED_LOCAL_ITERATION': config.worker.epoch.is_random,
        'LOCAL_ITERATION_TYPE': config.worker.epoch.type,
        'LOCAL_ITERATION_MEAN': config.worker.epoch.mean,
        'LOCAL_ITERATION_STD': config.worker.epoch.std,
        'LOCAL_ITERATION_BETA': config.worker.epoch.beta,
        'LOCAL_ITERATION_COEFFICIENT': config.worker.epoch.coef,
        'LOCAL_LEARNING_RATE': config.worker.tx.lr,
        'LOCAL_LEARNING_RATE_DECAY_PER_SERVER_ITERATION': config.worker.tx.lr_decay_per_server_epoch,
        'LOCAL_LEARNING_RATE_DECAY_PER_LOCAL_ITERATION': config.worker.tx.lr_decay_per_worker_epoch,

        'COMPRESSOR_ENABLE': config.compressor.enable,
        'COMPRESSOR_RESET_AGGREGATION': config.compressor.reset_aggregation,
        'COMPRESSOR_USE_INDIV_REFERENCE': config.compressor.use_indiv_reference,
        'COMPRESSOR_USE_POSTERIOR_PRIOR': config.compressor.use_posterior_prior,
        'UPLINK_SAMPLES': config.compressor.uplink_samples,
        'DOWNLINK_SAMPLES': config.compressor.downlink_samples,
        'COMMON_DL_PRIOR': config.compressor.common_dl_prior,
        'PROJECT_KL_DIVERGENCES': config.compressor.project_kl_divergences,
        'PROJECT_BLOCK_KL_DIVERGENCES': config.compressor.project_block_kl_divergences,
        'ADAPTIVE_BLOCKS_UL': config.compressor.adaptive_blocks_ul,
        'ADAPTIVE_BLOCKS_DL': config.compressor.adaptive_blocks_dl,
        'ADAPTIVE_AVG': config.compressor.adaptive_avg,
        'KL_RATE_UL': config.compressor.kl_rate_ul,
        'KL_RATE_DL': config.compressor.kl_rate_dl,
        'SPLIT_DL': config.compressor.split_dl,
        'MAX_BLOCK_SIZE': config.compressor.max_block_size,
        'BLOCK_SIZE': config.compressor.block_size,
        'SAMPLE_SIZE': config.compressor.sample_size,
        'REUSE_SAMPLES': config.compressor.reuse_samples,
        'AVG_DEV_FACTOR': config.compressor.avg_dev_factor,
    }
}