# Global verbosity flag (not referenced by the helper functions below).
VERBOSE = False

# Top-level model / training hyper-parameters. Read one entry with
# ``get_model_params(key)``; override entries with ``set_model_params``,
# which only accepts keys that already exist in this dict.
MODEL_HYPER_PARAMS = {
    # Name of the graph model; appears to select an entry of
    # GRAPH_MODEL_PARAMS (its keys include "SPNFG").
    "graph_model": "SPNFG",
    "sample_mask": True,
    "hard_mask": True,
    "fix_factor": False,
    # Prior parameters for z (presumably a Gaussian mean / std — confirm).
    "z_prior": {
        "mu": 0.0,
        "std": 1.0,
    },
    "edge_prior": 0.5,
    "nonlin": "relu",
    # Two learning rates, presumably for the neural-network and
    # factor-graph parameter groups — confirm against the optimizer setup.
    "lr_nn": 1e-4,
    "lr_fg": 1e-3,
    # Presumably per-term weights of the training loss.
    "loss_coeff": {
        "kl_z": 0.25,
        "kl_fg": 0.25,
        "l1_reg": 1.0,
        "l1_reg_int": 1.0,
    },
    # Appears to name an entry of SCHEDULER_PARAMS (its keys include "sigmoid").
    "coeff_scheduler": "sigmoid",
    "loss_memory": 0.9,
    "var_type": "const",
    "noise_level": 0.05,
    "robust_loss": False,
}


# Optimizer settings. Read with ``get_optim_params(key)``; override with
# ``set_optim_params``, which only accepts keys already present here.
OPTIM_PARAMS = {
    "weight_decay": 1e-3,
}


# Per-graph-model settings, keyed by model name. Read a model's dict with
# ``get_graph_params(model_name)``; override entries with
# ``set_graph_params(model_name, ...)``, which only accepts keys already
# defined for that model.
GRAPH_MODEL_PARAMS = {
    "BasicFG": {
        "tau": 1.0,
    },
    "SPNFG": {
        "spn_target": "node",
        "max_copies": 8,
        "tau": 1.0,
        "p_conn": 0.1,
        "sparsity_temp": 1.0,
    },
}

# Per-scheduler settings, keyed by scheduler name. The active coefficient
# scheduler is presumably chosen by MODEL_HYPER_PARAMS["coeff_scheduler"]
# — confirm against the scheduler factory. Read with
# ``get_scheduler_params``; override with ``set_scheduler_params``.
SCHEDULER_PARAMS = {
    "const": {
        "start_value": 1.0,
    },
    "linear": {
        "start_value": 1.0,
        "end_value": 0.0,
        # Presumably the length of the ramp, in epochs (per the key name).
        "num_epochs": 1000,
    },
    "cosine": {
        "start_value": 0.0,
        "end_value": 1.0,
        "period": 100,
    },
    "sigmoid": {
        "start_value": 0.0,
        "end_value": 1.0,
        # Units of midpoint / scale not visible here (presumably epochs/steps).
        "midpoint": 500,
        "scale": 125,
    },
    "sparsity": {
        "min_l1": 0.0,
        "start_value": 0.0,
        "lambda_l1": 0.05,
    },
}


def get_model_params(key):
    """Return the entry ``key`` of ``MODEL_HYPER_PARAMS``.

    The stored object itself is returned (no copy), so mutating a
    dict-valued entry changes the module-level configuration.

    Raises:
        KeyError: if ``key`` is not a hyper-parameter name.
    """
    hyper_params = MODEL_HYPER_PARAMS
    return hyper_params[key]


def set_model_params(**params):
    """Override existing entries of ``MODEL_HYPER_PARAMS`` in place.

    Each keyword must name a key already present in ``MODEL_HYPER_PARAMS``;
    its value replaces the stored one wholesale (a dict-valued entry is
    replaced, not merged). All keys are validated before anything is
    written, so a rejected call leaves the configuration untouched.

    Raises:
        AssertionError: if any keyword is not an existing hyper-parameter.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped under ``python -O``; the exception type and message are
            kept for callers that already rely on them.
    """
    # Validate everything first so a bad key cannot leave a half-applied update.
    for key in params:
        if key not in MODEL_HYPER_PARAMS:
            raise AssertionError(f"Invalid key {key}")
    MODEL_HYPER_PARAMS.update(params)


def get_graph_params(model_name):
    """Return the settings dict for ``model_name`` from ``GRAPH_MODEL_PARAMS``.

    The live dict is returned (no copy); mutating it changes the
    module-level configuration.

    Raises:
        KeyError: if ``model_name`` has no entry.
    """
    per_model = GRAPH_MODEL_PARAMS
    return per_model[model_name]

def get_scheduler_params(key):
    """Return the settings dict for scheduler ``key`` from ``SCHEDULER_PARAMS``.

    The live dict is returned (no copy); mutating it changes the
    module-level configuration.

    Raises:
        KeyError: if ``key`` is not a scheduler name.
    """
    per_scheduler = SCHEDULER_PARAMS
    return per_scheduler[key]


def set_graph_params(model_name, **params):
    """Override existing entries of ``GRAPH_MODEL_PARAMS[model_name]`` in place.

    Each keyword must name a key already defined for ``model_name``. All
    keys are validated before anything is written, so a rejected call
    leaves the configuration untouched. A call with no keyword arguments
    is a no-op and does not look ``model_name`` up.

    Raises:
        KeyError: if ``model_name`` is unknown and at least one keyword is
            given.
        AssertionError: if a keyword is not defined for ``model_name``.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped under ``python -O``; the exception type and message are
            kept for callers that already rely on them.
    """
    if not params:
        # Preserve the original behaviour: nothing to set, nothing looked up.
        return
    model_params = GRAPH_MODEL_PARAMS[model_name]
    # Validate everything first so a bad key cannot leave a half-applied update.
    for key in params:
        if key not in model_params:
            raise AssertionError(f"Invalid key {key} for model {model_name}")
    model_params.update(params)


def get_optim_params(key):
    """Return the optimizer setting ``key`` from ``OPTIM_PARAMS``.

    Raises:
        KeyError: if ``key`` is not an optimizer setting.
    """
    optim_settings = OPTIM_PARAMS
    return optim_settings[key]


def set_optim_params(**params):
    """Override existing entries of ``OPTIM_PARAMS`` in place.

    Each keyword must name a key already present in ``OPTIM_PARAMS``. All
    keys are validated before anything is written, so a rejected call
    leaves the configuration untouched.

    Raises:
        AssertionError: if any keyword is not an existing optimizer setting.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped under ``python -O``; the exception type and message are
            kept for callers that already rely on them.
    """
    # Validate everything first so a bad key cannot leave a half-applied update.
    for key in params:
        if key not in OPTIM_PARAMS:
            raise AssertionError(f"Invalid key {key}")
    OPTIM_PARAMS.update(params)

def set_scheduler_params(scheduler, **params):
    """Override existing entries of ``SCHEDULER_PARAMS[scheduler]`` in place.

    Each keyword must name a key already defined for ``scheduler``. All
    keys are validated before anything is written, so a rejected call
    leaves the configuration untouched. A call with no keyword arguments
    is a no-op and does not look ``scheduler`` up.

    Raises:
        KeyError: if ``scheduler`` is unknown and at least one keyword is
            given.
        AssertionError: if a keyword is not defined for ``scheduler``.
            Raised explicitly rather than via ``assert`` so the check is not
            stripped under ``python -O``; the exception type is kept for
            callers that already catch it. The message now names the
            scheduler, matching ``set_graph_params``.
    """
    if not params:
        # Preserve the original behaviour: nothing to set, nothing looked up.
        return
    scheduler_params = SCHEDULER_PARAMS[scheduler]
    # Validate everything first so a bad key cannot leave a half-applied update.
    for key in params:
        if key not in scheduler_params:
            raise AssertionError(f"Invalid key {key} for scheduler {scheduler}")
    scheduler_params.update(params)
