# --- Attribute-name strings -------------------------------------------------
# NOTE(review): presumably used as attribute/key names on layer objects;
# confirm against the callers.
WEIGHTS_PRUNING_ATTR = "mask_pruning"
WEIGHTS_BASE_ATTR = "mask_base"

WEIGHTS_ATTR = "weights"
BIAS_ATTR = "bias"
MASK_MERGED_ATTR = "mask_merged"

# --- Layer-type tags --------------------------------------------------------
CONV2D_LAYER = "conv2d"
BATCH_NORM_2D_LAYER = "batch_norm_2d"
DROPOUT_LAYER = "dropout"
FULLY_CONNECTED_LAYER = "fully_connected"

# --- Path fragments ---------------------------------------------------------
# Every value starts with "/".
# NOTE(review): unclear from this module whether these are absolute paths or
# suffixes appended to a project root; confirm before changing them.
PRUNED_MODELS_PATH = "/networks_pruned"
BASELINE_MODELS_PATH = "/networks_baseline"
IMAGENET_PATH = "/todo"  # NOTE(review): looks like a placeholder value

DATA_PATH = "/data"
EXPERIMENTS_RESULTS_PATH = "/experiments_outputs"

# Single-entry holders for the active optimizer settings.  The "value" entry
# of each dict is overwritten in place by config_sgd_setup() /
# config_adam_setup() and read back through the get_*() accessors below.
# Each holder is its own dict object; all start at 0.0.
LR_FLOW_PARAMS = {"value": 0.0}
LR_FLOW_PARAMS_RESET = {"value": 0.0}
FLOW_PARAMS_INITIALIZATION = {"value": 0.0}

INITIAL_LR = 0.1

def config_sgd_setup():
    """Activate the SGD presets.

    Copies ``LR_FLOW_PARAMS_SGD``, ``LR_FLOW_PARAMS_SGD_RESET`` and
    ``FLOW_PARAMS_INITIALIZATION_SGD`` into the ``"value"`` entry of
    ``LR_FLOW_PARAMS``, ``LR_FLOW_PARAMS_RESET`` and
    ``FLOW_PARAMS_INITIALIZATION`` respectively, so that the ``get_*``
    accessors return the SGD settings afterwards.

    Returns:
        None. The holder dicts are updated in place.
    """
    # The holders are mutated, never rebound, so no ``global`` declaration is
    # required; every existing reference to the dicts sees the new values.
    LR_FLOW_PARAMS["value"] = LR_FLOW_PARAMS_SGD
    LR_FLOW_PARAMS_RESET["value"] = LR_FLOW_PARAMS_SGD_RESET
    FLOW_PARAMS_INITIALIZATION["value"] = FLOW_PARAMS_INITIALIZATION_SGD

def config_adam_setup():
    """Activate the Adam presets.

    Copies ``LR_FLOW_PARAMS_ADAM``, ``LR_FLOW_PARAMS_ADAM_RESET`` and
    ``FLOW_PARAMS_INITIALIZATION_ADAM`` into the ``"value"`` entry of
    ``LR_FLOW_PARAMS``, ``LR_FLOW_PARAMS_RESET`` and
    ``FLOW_PARAMS_INITIALIZATION`` respectively, so that the ``get_*``
    accessors return the Adam settings afterwards.

    Returns:
        None. The holder dicts are updated in place.
    """
    # The holders are mutated, never rebound, so no ``global`` declaration is
    # required; every existing reference to the dicts sees the new values.
    LR_FLOW_PARAMS["value"] = LR_FLOW_PARAMS_ADAM
    LR_FLOW_PARAMS_RESET["value"] = LR_FLOW_PARAMS_ADAM_RESET
    FLOW_PARAMS_INITIALIZATION["value"] = FLOW_PARAMS_INITIALIZATION_ADAM

def get_lr_flow_params():
    """Return the ``"value"`` entry currently held in ``LR_FLOW_PARAMS``."""
    active_value = LR_FLOW_PARAMS["value"]
    return active_value

def get_lr_flow_params_reset():
    """Return the ``"value"`` entry currently held in ``LR_FLOW_PARAMS_RESET``."""
    active_value = LR_FLOW_PARAMS_RESET["value"]
    return active_value

def get_flow_params_init():
    """Return the ``"value"`` entry currently held in ``FLOW_PARAMS_INITIALIZATION``."""
    active_value = FLOW_PARAMS_INITIALIZATION["value"]
    return active_value


# Adam presets, copied into the holder dicts by config_adam_setup().
# Alternative value kept for reference: LR_FLOW_PARAMS_ADAM = 0.00075
LR_FLOW_PARAMS_ADAM = 1e-3
LR_FLOW_PARAMS_ADAM_RESET = LR_FLOW_PARAMS_ADAM / 10  # one tenth of the base rate
FLOW_PARAMS_INITIALIZATION_ADAM = 0.2

# SGD presets, copied into the holder dicts by config_sgd_setup().
LR_FLOW_PARAMS_SGD = 1e-3
LR_FLOW_PARAMS_SGD_RESET = LR_FLOW_PARAMS_SGD / 10  # one tenth of the base rate
FLOW_PARAMS_INITIALIZATION_SGD = 2e-4

# Scalar constants and message tag.
# NOTE(review): their exact roles are not visible in this module; confirm
# against the call sites before tuning.
GRADIENT_IDENTITY_SCALER = 0.25
SCHEDULER_MESSAGE = "SCHEDULER::"

N_SCALER = 3.75e-6

# Training epoch counts; the architecture and dataset are encoded in each name.
# CIFAR-10
TRAIN_EPOCHS_RESNET18_CIFAR10 = 200
TRAIN_EPOCHS_RESNET50_CIFAR10 = 200
TRAIN_EPOCHS_VGG19_CIFAR10 = 200

# CIFAR-100
TRAIN_EPOCHS_RESNET18_CIFAR100 = 200
TRAIN_EPOCHS_RESNET50_CIFAR100 = 200
TRAIN_EPOCHS_VGG19_CIFAR100 = 200

# ImageNet
TRAIN_EPOCHS_RESNET50_IMAGENET = 120