from sklearn.model_selection import GridSearchCV, ParameterGrid

# Candidate values swept under the 'C' name by the wb_db overrides; also the
# leading part of all_c_options.
dfr_c_options = [1.0, 0.7, 0.3, 0.1, 0.07, 0.03, 0.01]

# Single-point grid for the 'noreg' mode. It is expanded through ParameterGrid
# so that noreg_params is a plain {name: value} dict rather than a dict of lists.
# NOTE(review): if this feeds sklearn's LogisticRegression, recent releases
# expect penalty=None instead of the string 'none' -- confirm against the
# pinned scikit-learn version.
noreg_grid = {
    'penalty': ['none'],
    'solver': ['lbfgs'],
    'max_iter': [1000],
}
noreg_params = next(iter(ParameterGrid(noreg_grid)))

# Fixed settings stored under the 'svm' mode.
svm_params = {'kernel': 'linear', 'C': 1.0}


def _penalty_grid(c_list, c_name, penalties, solver):
    """Build a dict-of-lists grid that sweeps ``c_name`` over ``c_list``."""
    return {
        c_name: c_list,
        'penalty': penalties,
        'solver': [solver],
        'max_iter': [1000],
    }


def l1_grid(c_list, c_name):
    """Return the grid for an l1 penalty with the 'liblinear' solver."""
    return _penalty_grid(c_list, c_name, ['l1'], 'liblinear')


def l2_grid(c_list, c_name):
    """Return the grid for an l2 penalty with the 'lbfgs' solver."""
    return _penalty_grid(c_list, c_name, ['l2'], 'lbfgs')


def l1_l2_grid(c_list, c_name):
    """Return the grid covering both l1 and l2 with the 'liblinear' solver."""
    return _penalty_grid(c_list, c_name, ['l1', 'l2'], 'liblinear')

def get_params(c_list, c_name):
    """Return the flat list of candidate settings built from the l1/l2 grids.

    Args:
        c_list: values to sweep.
        c_name: parameter name the values are stored under (e.g. 'C').

    Returns:
        A list of dicts: every l1_grid configuration, then every l2_grid
        configuration, then one final entry with penalty 'none'.
    """
    configs = list(ParameterGrid(l1_grid(c_list, c_name)))
    configs.extend(ParameterGrid(l2_grid(c_list, c_name)))
    configs.append({'penalty': 'none', 'max_iter': 1000})
    return configs


def get_sgd_params(c_list, c_name):
    """Return the flat list of candidate settings without solver entries.

    Mirrors get_params, but each configuration only carries ``c_name`` and
    'penalty'. The previous version concatenated the 'l1' grid twice, which
    duplicated every l1 configuration and never tried 'l2'.

    Args:
        c_list: values to sweep.
        c_name: parameter name the values are stored under (e.g. 'alpha').

    Returns:
        A list of dicts: every l1 configuration, then every l2 configuration,
        then one final entry with penalty None.
    """
    configs = []
    # One grid per penalty keeps the ordering "all l1, then all l2".
    for penalty in ('l1', 'l2'):
        configs.extend(ParameterGrid({c_name: c_list, 'penalty': [penalty]}))
    configs.append({'penalty': None, 'max_iter': 1000})
    return configs

# Wider sweep of candidate values, used by the cif_db, bgchallenge_db and
# hard_imagenet_db overrides.
ez_c_options = [100, 10, 1., 0.5, 0.1, 0.01, 0.001]
# dfr_c_options first, followed by the values that only appear in
# ez_c_options. The appended tail comes out in set-iteration order, not in the
# order the values are listed above.
all_c_options = dfr_c_options + list(set(ez_c_options) - set(dfr_c_options))
# NOTE(review): rebinds dfr_c_options to the same values as its earlier
# definition -- looks redundant; confirm before removing.
dfr_c_options = [1., 0.7, 0.3, 0.1, 0.07, 0.03, 0.01]

# Fixed setting for the 'knn' mode and the values tried by 'gridknn'.
knn_params = dict(n_neighbors=5)
knn_grid = dict(n_neighbors=[3, 5, 10, 15])

# Fixed setting for the 'sgd' mode (hinge loss, no penalty) and the grid used
# by 'gridsgd'.
# NOTE(review): if this grid feeds sklearn's SGDClassifier, newer releases name
# the logistic loss 'log_loss' instead of 'log' -- confirm against the pinned
# scikit-learn version.
sgd_params = dict(loss='hinge', penalty=None)
sgd_grid = dict(
    loss=['hinge', 'log'],
    penalty=['l1', 'l2'],
    alpha=[1e-4, 1e-3, 1e-2, 0.1, 1.0],
)

# Default settings per mode. Dataset-specific dicts override individual
# entries of this table through update_db.
shared_db = {
    # Flat list of candidate settings over the combined value sweep.
    'all': get_params(all_c_options, 'C'),
    # Dict-of-lists grid (l1 and l2) over the same sweep.
    'gridlr': l1_l2_grid(all_c_options, 'C'),
    # Single fixed setting with penalty 'none'.
    'noreg': noreg_params,

    # Nearest-neighbour entries: fixed setting and grid.
    'knn': knn_params,
    'gridknn': knn_grid,

    # SGD entries: fixed setting and grid.
    'sgd': sgd_params,
    'gridsgd': sgd_grid,

    # Linear-kernel SVM setting.
    'svm': svm_params,
}

# Overrides layered on shared_db for the 'waterbirds' and 'celebA' entries of
# hparam_db; these sweep dfr_c_options instead of the combined list.
wb_db = {
    'all': get_params(dfr_c_options, 'C'),
    'gridlr': l1_grid(dfr_c_options, 'C'),
    # NOTE(review): shared_db has no 'sgdcv' key, so update_db(wb_db, shared_db)
    # drops this entry -- confirm whether a different key was intended.
    'sgdcv': sgd_params,
}

# Overrides layered on shared_db for the 'spur_cifar10' entry of hparam_db.
cif_db = {
    # Sweeps 'alpha' via get_sgd_params; the 'C'-based alternative would be
    # get_params(ez_c_options, 'C').
    'all': get_sgd_params(ez_c_options, 'alpha'),
    'gridlr': l1_l2_grid(ez_c_options, 'C'),
    # The penalty is the None object, matching sgd_params. The string 'None'
    # used previously is not one of the documented penalty values and is
    # rejected by scikit-learn versions that validate parameters.
    'sgd': {'loss': 'hinge', 'penalty': None},
}

# Overrides layered on shared_db for the 'bgchallenge' entry of hparam_db
# (settings for the SGD classifier).
bgchallenge_db = {
    # Flat list sweeping 'alpha'; list(ParameterGrid(sgd_grid)) would be the
    # full-grid alternative.
    'all': get_sgd_params(ez_c_options, 'alpha'),
    'gridlr': l1_l2_grid(ez_c_options, 'C'),
    'gridsgd': sgd_grid,
    # Hinge loss.
    'sgd': sgd_params,
    'knn': knn_params,
}

# Overrides layered on shared_db for the 'hard_imagenet' entry of hparam_db
# (settings for the SGD classifier).
hard_imagenet_db = {
    # Flat list sweeping 'C'; list(ParameterGrid(sgd_grid)) would be the
    # full-grid alternative.
    'all': get_params(ez_c_options, 'C'),
    'gridlr': l1_l2_grid(ez_c_options, 'C'),
    'gridsgd': sgd_grid,
    # Hinge loss.
    'sgd': sgd_params,
    'knn': knn_params,
}

def update_db(new_db, old_db):
    """Return a copy of ``old_db`` with values replaced from ``new_db``.

    Only keys already present in ``old_db`` appear in the result, in
    ``old_db``'s order; keys that exist only in ``new_db`` are ignored.
    Neither input is modified, and values are not copied.
    """
    merged = {}
    for key, default in old_db.items():
        if key in new_db:
            merged[key] = new_db[key]
        else:
            merged[key] = default
    return merged

# Dataset name -> {mode: settings}.
# 'all' should hold a list of dictionaries; the gridcv entries should hold a
# hyperparameter grid (dict of lists).
hparam_db = {
    # Datasets with their own overrides on top of shared_db.
    'waterbirds': update_db(wb_db, shared_db),
    'celebA': update_db(wb_db, shared_db),
    # The entries below that map straight to shared_db all reference the same
    # dict object.
    'bgchallenge_lin': shared_db,
    'spur_cifar10': update_db(cif_db, shared_db),
    'cmnist': shared_db,
    'metashift': shared_db,
    'bgchallenge': update_db(bgchallenge_db, shared_db),
    'hard_imagenet': update_db(hard_imagenet_db, shared_db),
}

# Every known mode name. The two slices below depend on this ordering.
all_modes = [
    'noreg', 'sgd', 'all', 'knn',
    'gridsgd', 'svm', 'gridknn', 'gridlr',
]
# Everything except the last two modes ('gridknn' and 'gridlr').
ez_modes = all_modes[:len(all_modes) - 2]
# The first four modes only.
lin_modes = all_modes[:4]

# Dataset name -> list of mode names.
grid_db = {
    # Full mode list.
    'waterbirds': all_modes,
    'metashift': all_modes,
    # Reduced mode list.
    'celebA': ez_modes,
    'cmnist': ez_modes,
    # First four modes.
    'spur_cifar10': lin_modes,
    # Hand-picked subset.
    'hard_imagenet': ['sgd', 'gridsgd', 'all'],
}
