from tensorflow import keras as keras
from keras import optimizers, losses, metrics
from ml_collections import config_dict
import datetime
from models import *

# * Configuration
# Root ConfigDict; every section below (train, data, beta, server, worker,
# compressor, wandb) is attached to this object.
config = config_dict.ConfigDict()
config.seed = 1  # ! Not used currently (within this file only read by wandb_config as RANDOM_SEED)
config.experiment = 0  # ! Not used currently (within this file only read by wandb_config as EXPERIMENT)

# * Training Configuration
# NOTE: model / loss / metric entries store classes or callables, not
# instances; presumably the training code instantiates them — TODO confirm.
config.train = config_dict.ConfigDict()
config.train.num_runs = 3  # presumably number of repeated runs — TODO confirm; not logged in wandb_config
config.train.model = lenet5  # not defined in this file; provided by `from models import *`
config.train.loss_fn = losses.SparseCategoricalCrossentropy  # Keras loss class (not an instance)
config.train.use_reg = True  # presumably toggles use of weight_decay below — TODO confirm
config.train.weight_decay = 1e-4  # not logged in wandb_config
config.train.batchwise_update = True  # logged in wandb_config as UPDATE_MODEL_EVERY_BATCH
config.train.metric = config_dict.ConfigDict()
config.train.metric.loss = metrics.SparseCategoricalCrossentropy  # Keras metric class (not an instance)
config.train.metric.acc = metrics.SparseCategoricalAccuracy  # Keras metric class (not an instance)
config.train.compute_divs = False  # not logged in wandb_config

# * Data Configuration
config.data = config_dict.ConfigDict()
config.data.name = 'mnist'  # logged in wandb_config as TRAINING_DATASET
config.data.alloc_type = 'dirichlet'  # presumably the per-worker data split scheme — TODO confirm
config.data.alloc_ratio = 4
config.data.beta = 0.5  # logged as DATASET_BETA; presumably the Dirichlet concentration — TODO confirm
config.data.batch_size = 128  # logged in wandb_config as BATCH_LENGTH
config.data.num_validation = 5  # logged in wandb_config as VALIDATION_SAMPLE_NUMBER_PER_CLASS
# The two fields below are placeholders overwritten at runtime. ConfigDict is
# type-safe by default, so the runtime values presumably keep these types
# (tuple / int) — verify.
config.data.shape = ()  # ! TBD in Code
config.data.num_classes = 0  # ! TBD in Code

# * Beta Configuration
# NOTE(review): separate from config.data.beta and config.worker.epoch.beta.
# With mode 'value', presumably `value` is used directly and `index` only
# matters for some other mode — TODO confirm in the consuming code.
config.beta = config_dict.ConfigDict()
config.beta.mode = 'value'  # logged in wandb_config as BETA_MODE
config.beta.value = 0.5  # logged in wandb_config as BETA
config.beta.index = 0  # logged in wandb_config as BETA_INDEX

# * Server Configuration
config.server = config_dict.ConfigDict()

config.server.num_epochs = 200  # logged in wandb_config as SERVER_MAX_ITERATION
config.server.update_thresh = 0.75  # logged in wandb_config as CLIENT_UPDATE_THRESHOLD_RATIO
# Early-stopping settings (disabled: enable = False).
config.server.es = config_dict.ConfigDict()
config.server.es.enable = False
config.server.es.threshold = 0.1  # ! Not used currently
config.server.es.wait = 0  # presumably a running counter compared against patience — TODO confirm
config.server.es.patience = 3
config.server.lr = 0.005  # NOTE(review): not logged in wandb_config

# * Worker Configuration
config.worker = config_dict.ConfigDict()
config.worker.num = 10  # logged in wandb_config as CLIENT_NUMBER
config.worker.inact_prob = 0  # logged in wandb_config as INACTIVE_PROBABILITY

# Local-epoch schedule. With type 'constant', presumably std / beta / coef
# only matter for other schedule types — TODO confirm.
config.worker.epoch = config_dict.ConfigDict()
config.worker.epoch.type = 'constant'
config.worker.epoch.mean = 3
config.worker.epoch.std = 2
config.worker.epoch.beta = 0.5
config.worker.epoch.coef = 20
config.worker.epoch.is_random = False  # not logged in wandb_config

# Local optimizer settings.
config.worker.tx = config_dict.ConfigDict()
config.worker.tx.fn = optimizers.Adam  # Keras optimizer class (not an instance)
config.worker.tx.lr = 0.0003
config.worker.tx.moment = 0.9  # NOTE(review): Keras Adam has no `momentum` argument; verify how this is consumed
config.worker.tx.lr_decay_per_server_epoch = 1.0  # presumably 1.0 means no decay — TODO confirm
config.worker.tx.type_decay_per_server_epoch = 'Geometric'
config.worker.tx.lr_decay_per_worker_epoch = 1.0  # presumably 1.0 means no decay — TODO confirm
config.worker.tx.type_decay_per_worker_epoch = 'Geometric'

# * Compressor Configuration
# Suffixes `_ul` / `_dl` presumably mean uplink / downlink (cf. uplink_samples
# and downlink_samples below) — TODO confirm.
config.compressor = config_dict.ConfigDict()
config.compressor.enable = True
config.compressor.reset_aggregation = True
config.compressor.use_indiv_reference = False
config.compressor.use_posterior_prior = False
config.compressor.uplink_samples = 1
config.compressor.downlink_samples = 10
config.compressor.common_dl_prior = True
config.compressor.reuse_samples = True
config.compressor.project_kl_divergences = None  # None disables projection; otherwise the KL divergence in bits
config.compressor.project_block_kl_divergences = None  # None disables projection; otherwise the per-block KL divergence in bits
config.compressor.adaptive_blocks_ul = False
config.compressor.adaptive_blocks_dl = False
config.compressor.adaptive_avg = False
config.compressor.avg_dev_factor = 2
config.compressor.kl_rate_ul = 8.0  # units not shown in this file — TODO confirm
config.compressor.kl_rate_dl = 8.0  # units not shown in this file — TODO confirm
config.compressor.split_dl = False
# NOTE(review): presumably block_size must not exceed max_block_size — confirm.
config.compressor.max_block_size = 512
config.compressor.block_size = 256
config.compressor.sample_size = 256

# * Weights & Biases Configuration
# Presumably forwarded to wandb.init — TODO confirm in the consuming code.
config.wandb = config_dict.ConfigDict()
config.wandb.name = ''  # ! TBD in Code
config.wandb.project = 'pm'
config.wandb.job_type = 'pm'

# * Flat, upper-case snapshot of the hyperparameters under 'PARAMETERS',
# presumably handed to wandb as the run config.
# Values are read once when this module executes, so later in-code changes to
# `config` are NOT reflected here.
# NOTE(review): not every `config` field is recorded — e.g. train.num_runs,
# train.weight_decay, train.use_reg, server.lr, server.es.patience,
# worker.tx.fn and worker.tx.moment are absent.
wandb_config = {
    'PARAMETERS': {
        'RANDOM_SEED': config.seed,
        'EXPERIMENT': config.experiment,

        'TRAINING_DATASET': config.data.name,
        'ALLOCATION_TYPE': config.data.alloc_type,
        'ALLOCATION_RATIO': config.data.alloc_ratio,
        'BATCH_LENGTH': config.data.batch_size,
        'DATASET_BETA': config.data.beta,
        'VALIDATION_SAMPLE_NUMBER_PER_CLASS': config.data.num_validation,

        'BETA_MODE': config.beta.mode,
        'BETA_INDEX': config.beta.index,
        'BETA': config.beta.value,

        'SERVER_MAX_ITERATION': config.server.num_epochs,
        'CLIENT_UPDATE_THRESHOLD_RATIO': config.server.update_thresh,
        # Log the scalar on/off flag rather than the whole `es` ConfigDict,
        # matching every other (scalar) entry in this dict.
        'EARLY_STOPPING': config.server.es.enable,
        'EARLY_STOPPING_RATE': config.server.es.threshold,

        'CLIENT_NUMBER': config.worker.num,
        'INACTIVE_PROBABILITY': config.worker.inact_prob,
        'UPDATE_MODEL_EVERY_BATCH': config.train.batchwise_update,
        # NOTE(review): same value as LOCAL_ITERATION_TYPE; possibly meant to be
        # config.worker.epoch.is_random — confirm before changing.
        'VARIED_LOCAL_ITERATION': config.worker.epoch.type,
        'LOCAL_ITERATION_TYPE': config.worker.epoch.type,
        'LOCAL_ITERATION_MEAN': config.worker.epoch.mean,
        'LOCAL_ITERATION_STD': config.worker.epoch.std,
        'LOCAL_ITERATION_BETA': config.worker.epoch.beta,
        'LOCAL_ITERATION_COEFFICIENT': config.worker.epoch.coef,
        'LOCAL_LEARNING_RATE': config.worker.tx.lr,
        'LOCAL_LEARNING_RATE_DECAY_PER_SERVER_ITERATION': config.worker.tx.lr_decay_per_server_epoch,
        'LOCAL_LEARNING_RATE_DECAY_PER_LOCAL_ITERATION': config.worker.tx.lr_decay_per_worker_epoch,

        'COMPRESSOR_ENABLE': config.compressor.enable,
        'COMPRESSOR_RESET_AGGREGATION': config.compressor.reset_aggregation,
        'COMPRESSOR_USE_INDIV_REFERENCE': config.compressor.use_indiv_reference,
        'COMPRESSOR_USE_POSTERIOR_PRIOR': config.compressor.use_posterior_prior,
        'UPLINK_SAMPLES': config.compressor.uplink_samples,
        'DOWNLINK_SAMPLES': config.compressor.downlink_samples,
        'COMMON_DL_PRIOR': config.compressor.common_dl_prior,
        'PROJECT_KL_DIVERGENCES': config.compressor.project_kl_divergences,
        'PROJECT_BLOCK_KL_DIVERGENCES': config.compressor.project_block_kl_divergences,
        'ADAPTIVE_BLOCKS_UL': config.compressor.adaptive_blocks_ul,
        'ADAPTIVE_BLOCKS_DL': config.compressor.adaptive_blocks_dl,
        'ADAPTIVE_AVG': config.compressor.adaptive_avg,
        'KL_RATE_UL': config.compressor.kl_rate_ul,
        'KL_RATE_DL': config.compressor.kl_rate_dl,
        'SPLIT_DL': config.compressor.split_dl,
        'MAX_BLOCK_SIZE': config.compressor.max_block_size,
        'BLOCK_SIZE': config.compressor.block_size,
        'SAMPLE_SIZE': config.compressor.sample_size,
        'REUSE_SAMPLES': config.compressor.reuse_samples,
        'AVG_DEV_FACTOR': config.compressor.avg_dev_factor,
    }
}
