# Hyper-parameters for the "mnist" entry of the `config` registry below.
mnist_config = {
    # worker count and optimiser settings
    "num-workers": 8,
    "optim": "adam",
    "lr": 0.0012,
    # no scheduler is configured for this dataset, so its knobs are unset
    "scheduler": None,
    "step_size": None,
    "gamma": None,
    # batching and run length
    "batch-size": 60,
    "epochs": 50,
    # rewinding switches (both 0 here; interpretation is up to the consumer)
    "lr-rewinding": 0,
    "weight-rewinding": 0,
}

# Hyper-parameters for the "cifar10" entry of the `config` registry below.
cifar10_config = {
    # worker count and optimiser settings
    "num-workers": 12,
    "optim": "sgd_m",
    "lr": 0.1,
    # step-style scheduler: step size 30, decay factor 0.1
    "scheduler": "steplr",
    "step_size": 30,
    "gamma": 0.1,
    # batching and run length
    "batch-size": 128,
    # NOTE(review): only this config defines "structure_epochs"; callers
    # reading it for other datasets would hit a KeyError — confirm intent.
    "structure_epochs": 5,
    "epochs": 120,
    # rewinding switches (both 0 here; interpretation is up to the consumer)
    "lr-rewinding": 0,
    "weight-rewinding": 0,
}

# Hyper-parameters for the "tiny-imagenet" entry of the `config` registry below.
tiny_imagenet_config = {
    # worker count and optimiser settings
    "num-workers": 16,
    "optim": "sgd_m",
    "lr": 0.4,
    # named scheduler "imagenet"; step_size/gamma are left unset here
    "scheduler": "imagenet",
    "step_size": None,
    "gamma": None,
    # batching and run length
    "batch-size": 512,
    "epochs": 90,
    # rewinding switches (both 0 here; interpretation is up to the consumer)
    "lr-rewinding": 0,
    "weight-rewinding": 0,
}

# Registry mapping a dataset name (as passed in `args.dataset`) to its
# hyper-parameter dict; consulted by get_config().
config = {
    "mnist": mnist_config,
    "cifar10": cifar10_config,
    "tiny-imagenet": tiny_imagenet_config,
}


def get_config(args):
    """Return the training hyper-parameters for ``args.dataset``.

    Prints a header line and the selected configuration, then returns it.

    Args:
        args: Parsed command-line namespace; only ``args.dataset`` is read.

    Returns:
        A shallow copy of the matching entry in the module-level ``config``
        registry. A copy is returned so callers can override values (e.g. the
        learning rate) without mutating the shared defaults seen by later calls.

    Raises:
        KeyError: If ``args.dataset`` is not a registered dataset name; the
            message lists the supported names.
    """
    print("--- Load configuration")
    try:
        dataset_config = config[args.dataset]
    except KeyError:
        # Re-raise with the valid choices so a typo on the CLI is easy to spot.
        raise KeyError(
            f"Unknown dataset {args.dataset!r}; expected one of {sorted(config)}"
        ) from None
    print(dataset_config)
    # All stored values are scalars/strings/None, so a shallow copy suffices.
    return dict(dataset_config)