from configs.algorithm import algorithm_defaults

adapter_params = {
    'lr': [1e-3, 1e-4, 1e-5],
    'weight-decay': [5e-5, 1e-5, 5e-4],
    'n_epochs': [100]
}

dataset_hparams = {
    'hard_imagenet': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.1],
            'batch-size': [128],
            'n_epochs': [400],
            'weight-decay': [1e-4],
        }, 
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.3], 
            'lr_min': [1e-3],
            'batch-size': [128],
            'weight-decay': [1e-5],
            'n_epochs': [800],
        }
    },
    'celebA': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.01],
            'batch-size': [128],
            'weight-decay': [1e-4],
            'n_epochs': [400],
        },
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.3], # initial lr = 0.3 * batch_size / 256
            'lr_min': [1e-3], 
            'batch-size': [256],
            'weight-decay': [1.0e-6], 
            'n_epochs': [400],
        }
    },
    'spur_cifar10': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.02],
            'batch-size': [128],
            'n_epochs': [800],
            'weight-decay': [5e-4],
        },
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.3],
            'batch-size': [256],
            'lr_min': [1e-3],
            'weight-decay': [1.0e-6],
            'n_epochs': [1000],
        }
    },
    'cmnist': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [1e-3],
            'batch-size': [128],
            'weight-decay': [1e-5],
            'n_epochs': [1000],
        },
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.6], # initial lr = 0.3 * batch_size / 256
            'lr_min': [1e-3], # initial lr = 0.3 * batch_size / 256
            'batch-size': [512],
            'weight-decay': [1.0e-6], 
            'n_epochs': [1000],
        }
    },
    'waterbirds': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.01],
            'batch-size': [64],
            'weight-decay': [1e-3],
            'n_epochs': [800],
        },
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.3], 
            'lr_min': [1e-3], 
            'batch-size': [256],
            'weight-decay': [1.0e-6], 
            'n_epochs': [1000],
        }
    },
    'metashift': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.05],
            'weight-decay': [0.001],
            'batch-size': [256],
        },
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.6], 
            'lr_min': [1e-3], 
            'batch-size': [512],
            'weight-decay': [1.0e-6],
        }
    },
    'cifar10': {
        'erm': {
            'algorithm': ['ERM'],
            'lr': [0.01], 
            'weight_decay': [5e-4],
            'scheduler': ['CosineAnnealingLR'],
            'n_epochs': [200]
        }
    },
    'bgchallenge': {
        'simsiam': {
            'algorithm': ['simsiam'],
            'lr': [0.05],
            'batch-size': [256],
            'n_epochs': [800],
            'weight-decay': [1e-4],
        }, 
        'simclr': {
            'algorithm': ['simclr'],
            'lr': [0.3], 
            'lr_min': [1e-3],
            'batch-size': [128],
            'weight-decay': [1e-5],
            'n_epochs': [1000],
        }
    },
}

def hp(dataset, algorithm='simsiam'):
    return {k: [v] for k, v in dataset_hparams[dataset][algorithm].items()}
