# Default configuration values for every supported dataset, keyed by dataset
# name. Each entry maps option names (split_scheme, model, loss_function,
# val_metric, optimizer, lr, batch_size, n_epochs, ...) to their defaults.
# NOTE(review): the code that consumes this table is not shown here; presumably
# these values are used for any option the user did not set explicitly —
# confirm against the config-population code.
# Nested dicts (model_kwargs, optimizer_kwargs, scheduler_kwargs,
# loader_kwargs, dataset_kwargs) look like keyword arguments forwarded to the
# matching constructor — TODO confirm.
# irm_lambda / coral_penalty_weight / n_groups_per_batch are presumably read
# only by group-aware algorithms (IRM, CORAL, group-balanced sampling) — confirm.
# Datasets with several split schemes (amazon, yelp) leave groupby_fields and
# val_metric out of this table; `split_defaults` at the bottom of this file
# supplies them per split_scheme.
dataset_defaults = {
    'amazon': {
        # groupby_fields / val_metric come from amazon_split_defaults.
        'split_scheme': 'official',
        'model': 'distilbert-base-uncased',
        'transform': 'bert',
        'max_token_length': 512,
        'loss_function': 'cross_entropy',
        'algo_log_metric': 'accuracy',
        'batch_size': 8,
        'lr': 1e-5,
        'weight_decay': 0.01,
        'n_epochs': 3,
        'n_groups_per_batch': 2,
        'irm_lambda': 1.0,
        'coral_penalty_weight': 1.0,
        'loader_kwargs': {
            'num_workers': 1,
            'pin_memory': True,
        },
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'bdd100k': {
        'split_scheme': 'official',
        'model': 'resnet50',
        'model_kwargs': {'pretrained': True},
        'loss_function': 'multitask_bce',
        'val_metric': 'acc_all',
        'val_metric_decreasing': False,
        'optimizer': 'SGD',
        'optimizer_kwargs': {'momentum': 0.9},
        'batch_size': 32,
        'lr': 0.001,
        'weight_decay': 0.0001,
        'n_epochs': 10,
        'algo_log_metric': 'multitask_binary_accuracy',
        'transform': 'image_base',
        'process_outputs_function': 'binary_logits_to_pred',
    },
    'camelyon17': {
        'split_scheme': 'official',
        'model': 'densenet121',
        'model_kwargs': {'pretrained': False},  # no pretrained weights
        'transform': 'image_base',
        'target_resolution': (96, 96),
        'loss_function': 'cross_entropy',
        'groupby_fields': ['hospital'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
        'optimizer': 'SGD',
        'optimizer_kwargs': {'momentum': 0.9},
        'scheduler': None,
        'batch_size': 400,
        'lr': 1e-5,
        'weight_decay': 0.01,
        'n_epochs': 3,
        'n_groups_per_batch': 2,
        'irm_lambda': 1.0,
        'coral_penalty_weight': 0.1,
        'algo_log_metric': 'accuracy',
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'celebA': {
        'split_scheme': 'official',
        'model': 'resnet50',
        'model_kwargs': {'pretrained': True},
        'transform': 'image_base',
        'loss_function': 'cross_entropy',
        'groupby_fields': ['male', 'y'],
        'val_metric': 'acc_wg',
        'val_metric_decreasing': False,
        'optimizer': 'SGD',
        'optimizer_kwargs': {'momentum': 0.9},
        'scheduler': None,
        'batch_size': 64,
        'lr': 0.001,
        'weight_decay': 0.0,
        'n_epochs': 200,
        'algo_log_metric': 'accuracy',
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'civilcomments': {
        'split_scheme': 'official',
        'model': 'distilbert-base-uncased',
        'transform': 'bert',
        'loss_function': 'cross_entropy',
        'groupby_fields': ['black', 'y'],
        'val_metric': 'acc_wg',
        'val_metric_decreasing': False,
        'batch_size': 16,
        'lr': 1e-5,
        'weight_decay': 0.01,
        'n_epochs': 5,
        'algo_log_metric': 'accuracy',
        'max_token_length': 300,
        'irm_lambda': 1.0,
        'coral_penalty_weight': 10.0,
        'loader_kwargs': {
            'num_workers': 1,
            'pin_memory': True,
        },
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'encode': {
        'split_scheme': 'official',
        'model': 'unet-seq',
        'model_kwargs': {'n_channels_in': 5},
        'loader_kwargs': {'num_workers': 1}, # single worker: pybigwig seems to have trouble with multiprocessing
        # NOTE(review): this is the only entry that uses train_transform /
        # eval_transform instead of 'transform' — confirm the consumer reads
        # these keys.
        'train_transform': None,
        'eval_transform': None,
        'loss_function': 'multitask_bce',
        'groupby_fields': ['celltype'],
        'val_metric': 'avgprec-macro_all',
        'val_metric_decreasing': False,
        'optimizer': 'Adam',
        'scheduler': 'MultiStepLR',
        'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1},
        'batch_size': 128,
        'lr': 1e-3,
        'weight_decay': 1e-4,
        'n_epochs': 12,
        'n_groups_per_batch': 4,
        'algo_log_metric': 'multitask_binary_accuracy',
        'irm_lambda': 100.0,
        'coral_penalty_weight': 0.1,
    },
    'fmow': {
        'split_scheme': 'official',
        'dataset_kwargs': {
            'seed': 111,
            'use_ood_val': True
        },
        'model': 'densenet121',
        'model_kwargs': {'pretrained': True},
        'transform': 'image_base',
        'loss_function': 'cross_entropy',
        'groupby_fields': ['year',],
        'val_metric': 'acc_worst_region',
        'val_metric_decreasing': False,
        'optimizer': 'Adam',
        'scheduler': 'StepLR',
        'scheduler_kwargs': {'gamma': 0.96},
        'batch_size': 64,
        'lr': 0.0001,
        'weight_decay': 0.0,
        'n_epochs': 50,
        'n_groups_per_batch': 8,
        'irm_lambda': 1.0,
        'coral_penalty_weight': 0.1,
        'algo_log_metric': 'accuracy',
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'iwildcam': {
        # NOTE: keys are listed in a different order from the other entries;
        # the set of options is the same kind of configuration.
        'loss_function': 'cross_entropy',
        'val_metric': 'F1-macro_all',
        'model_kwargs': {'pretrained': True},
        'transform': 'image_base',
        'target_resolution': (448, 448),
        'val_metric_decreasing': False,
        'algo_log_metric': 'accuracy',
        'model': 'resnet50',
        'lr': 3e-5,
        'weight_decay': 0.0,
        'batch_size': 16,
        'n_epochs': 12,
        'optimizer': 'Adam',
        'split_scheme': 'official',
        'scheduler': None,
        'groupby_fields': ['location',],
        'n_groups_per_batch': 2,
        'irm_lambda': 1.,
        'coral_penalty_weight': 10.,
        'no_group_logging': True,  # presumably suppresses per-group metric logging — confirm
        'process_outputs_function': 'multiclass_logits_to_pred'
    },
    'ogb-molpcba': {
        'split_scheme': 'official',
        'model': 'gin-virtual',
        'model_kwargs': {'dropout':0.5}, # TODO: "include pretrained" — presumably a pretrained-weights option is still to be added
        'loss_function': 'multitask_bce',
        'groupby_fields': ['scaffold',],
        'val_metric': 'ap',
        'val_metric_decreasing': False,
        'optimizer': 'Adam',
        'batch_size': 32,
        'lr': 1e-03,
        'weight_decay': 0.,
        'n_epochs': 100,
        'n_groups_per_batch': 4,
        'irm_lambda': 1.,
        'coral_penalty_weight': 0.1,
        'no_group_logging': True,
        'process_outputs_function': None,
        'algo_log_metric': 'multitask_binary_accuracy',
    },
    'py150': {
        'split_scheme': 'official',
        'model': 'code-gpt-py',
        'loss_function': 'lm_cross_entropy',
        'val_metric': 'acc',
        'val_metric_decreasing': False,
        'optimizer': 'AdamW',
        'optimizer_kwargs': {'eps':1e-8},
        'lr': 8e-5,
        'weight_decay': 0.,
        'n_epochs': 3,
        'batch_size': 6,
        'groupby_fields': ['repo',],
        'n_groups_per_batch': 2,
        'irm_lambda': 1.,
        'coral_penalty_weight': 1.,
        'no_group_logging': True,
        'algo_log_metric': 'multitask_accuracy',
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'poverty': {
        'split_scheme': 'official',
        'dataset_kwargs': {
            'no_nl': False,
            'fold': 'A',
            'use_ood_val': True
        },
        'model': 'resnet18_ms',
        # presumably must match the number of input bands produced by the
        # dataset (which may depend on no_nl) — confirm.
        'model_kwargs': {'num_channels': 8},
        'transform': 'poverty',
        'loss_function': 'mse',
        'groupby_fields': ['country',],
        'val_metric': 'r_wg',
        'val_metric_decreasing': False,
        'algo_log_metric': 'mse',
        'optimizer': 'Adam',
        'scheduler': 'StepLR',
        'scheduler_kwargs': {'gamma':0.96},
        'batch_size': 64,
        'lr': 0.001,
        'weight_decay': 0.0,
        'n_epochs': 200,
        'n_groups_per_batch': 8,
        'irm_lambda': 1.0,
        'coral_penalty_weight': 0.1,
        'process_outputs_function': None,
    },
    'waterbirds': {
        'split_scheme': 'official',
        'model': 'resnet50',
        'transform': 'image_resize_and_center_crop',
        # presumably the ratio of the resize size to the center-crop size
        # (256 -> 224) — confirm against the transform implementation.
        'resize_scale': 256.0/224.0,
        'model_kwargs': {'pretrained': True},
        'loss_function': 'cross_entropy',
        'groupby_fields': ['background', 'y'],
        'val_metric': 'acc_wg',
        'val_metric_decreasing': False,
        'algo_log_metric': 'accuracy',
        'optimizer': 'SGD',
        'optimizer_kwargs': {'momentum':0.9},
        'scheduler': None,
        'batch_size': 128,
        'lr': 1e-5,
        'weight_decay': 1.0,  # NOTE(review): far larger than other entries — confirm intended
        'n_epochs': 300,
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'yelp': {
        # groupby_fields / val_metric come from yelp_split_defaults.
        # NOTE(review): sets n_groups_per_batch but, unlike 'amazon', no
        # irm_lambda / coral_penalty_weight — confirm those have a fallback.
        'split_scheme': 'official',
        'model': 'bert-base-uncased',
        'transform': 'bert',
        'max_token_length': 512,
        'loss_function': 'cross_entropy',
        'algo_log_metric': 'accuracy',
        'batch_size': 8,
        'lr': 2e-6,
        'weight_decay': 0.01,
        'n_epochs': 3,
        'n_groups_per_batch': 2,
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'sqf': {
        'split_scheme': 'all_race',
        'model': 'logistic_regression',
        'transform': None,
        # presumably must equal the feature dimension of the SQF data — confirm.
        'model_kwargs': {'in_features': 104},
        'loss_function': 'cross_entropy',
        'groupby_fields': ['y'],
        'val_metric': 'precision_at_global_recall_all',
        'val_metric_decreasing': False,
        'algo_log_metric': 'accuracy',
        'optimizer': 'Adam',
        'optimizer_kwargs': {},
        'scheduler': None,
        'batch_size': 4,
        'lr': 5e-5,
        'weight_decay': 0,
        'n_epochs': 4,
        'process_outputs_function': None,
    },
    'rxrx1': {
        'split_scheme': 'official',
        'model': 'resnet50',
        'model_kwargs': {'pretrained': True},
        'transform': 'rxrx1',
        'target_resolution': (256, 256),
        'loss_function': 'cross_entropy',
        'groupby_fields': ['experiment'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
        'algo_log_metric': 'accuracy',
        'optimizer': 'Adam',
        'optimizer_kwargs': {},
        'scheduler': 'cosine_schedule_with_warmup',
        # presumably a count of optimizer steps, so it may need revisiting if
        # batch_size changes — confirm.
        'scheduler_kwargs': {'num_warmup_steps': 5415},
        'batch_size': 72,
        'lr': 1e-3,
        'weight_decay': 1e-5,
        'n_groups_per_batch': 9,
        'coral_penalty_weight': 0.1,
        'irm_lambda': 1.0,
        'n_epochs': 90,
        'process_outputs_function': 'multiclass_logits_to_pred',
    },
    'globalwheat': {
        'split_scheme': 'official',
        'model': 'fasterrcnn',
        'transform': 'image_base',
        'model_kwargs': {
            'n_classes': 1,
            'pretrained': True
        },
        'loss_function': 'fasterrcnn_criterion',
        'groupby_fields': ['session'],
        'val_metric': 'detection_acc_avg_dom',
        'val_metric_decreasing': False,
        'algo_log_metric': None, # TODO: no algo_log_metric chosen for this detection task yet
        'optimizer': 'Adam',
        'optimizer_kwargs': {},
        'scheduler': None,
        'batch_size': 4,
        'lr': 1e-5,
        'weight_decay': 1e-3,
        'n_epochs': 10,
        'loader_kwargs': {
            'num_workers': 1,
            'pin_memory': True,
        },
        'process_outputs_function': None,
    }
}

##########################################
### Split-specific defaults for Amazon ###
##########################################

# Split-specific overrides for Amazon, keyed by split_scheme name. Each value
# supplies groupby_fields / val_metric / val_metric_decreasing (and, for the
# per-user schemes, no_group_logging), which dataset_defaults['amazon'] does
# not set.
amazon_split_defaults = {
    'official':{
        'groupby_fields': ['user'],
        'val_metric': '10th_percentile_acc',
        'val_metric_decreasing': False,
        'no_group_logging': True,
    },
    'user':{
        'groupby_fields': ['user'],
        'val_metric': '10th_percentile_acc',
        'val_metric_decreasing': False,
        'no_group_logging': True,
    },
    'time':{
        'groupby_fields': ['year'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
    },
    'time_baseline':{
        'groupby_fields': ['year'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
    },
}

# Per-user baseline split schemes ('<id>_baseline'; the ids are presumably
# reviewer IDs): grouped by user and selected on average accuracy.
user_baseline_splits = [
    'A1CNQTCRQ35IMM_baseline', 'A1NE43T0OM6NNX_baseline', 'A1UH21GLZTYYR5_baseline', 'A20EEWWSFMZ1PN_baseline',
    'A219Y76LD1VP4N_baseline', 'A37BRR2L8PX3R2_baseline', 'A3JVZY05VLMYEM_baseline', 'A9Q28YTLYREO7_baseline',
    'ASVY5XSYJ1XOE_baseline', 'AV6QDP8Q0ONK4_baseline'
    ]
# A dict comprehension (rather than a module-level ``for`` loop) keeps the
# iteration variable out of the module namespace, so no stray ``split`` global
# is left behind or exported by ``from ... import *``. Each split still gets
# its own freshly built dict and list.
amazon_split_defaults.update({
    split: {
        'groupby_fields': ['user'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
        }
    for split in user_baseline_splits
})

# Category-based split schemes: grouped by product category and selected on
# average accuracy.
category_splits = [
    'arts_crafts_and_sewing_generalization', 'automotive_generalization',
    'books,movies_and_tv,home_and_kitchen,electronics_generalization', 'books_generalization', 'category_subpopulation',
    'cds_and_vinyl_generalization', 'cell_phones_and_accessories_generalization', 'clothing_shoes_and_jewelry_generalization',
    'digital_music_generalization', 'electronics_generalization', 'grocery_and_gourmet_food_generalization',
    'home_and_kitchen_generalization', 'industrial_and_scientific_generalization', 'kindle_store_generalization',
    'luxury_beauty_generalization', 'movies_and_tv,books,home_and_kitchen_generalization', 'movies_and_tv,books_generalization',
    'movies_and_tv_generalization', 'musical_instruments_generalization', 'office_products_generalization',
    'patio_lawn_and_garden_generalization', 'pet_supplies_generalization', 'prime_pantry_generalization',
    'sports_and_outdoors_generalization', 'tools_and_home_improvement_generalization', 'toys_and_games_generalization',
    'video_games_generalization',
    ]
amazon_split_defaults.update({
    split: {
        'groupby_fields': ['category'],
        'val_metric': 'acc_avg',
        'val_metric_decreasing': False,
        }
    for split in category_splits
})

########################################
### Split-specific defaults for Yelp ###
########################################

# Split-specific overrides for Yelp, keyed by split_scheme name. Unlike
# Amazon, the 'official' scheme here groups by year and is selected on average
# accuracy (the same settings as 'time' / 'time_baseline').
yelp_split_defaults = dict(
    official=dict(
        groupby_fields=['year'],
        val_metric='acc_avg',
        val_metric_decreasing=False,
    ),
    user=dict(
        groupby_fields=['user'],
        val_metric='10th_percentile_acc',
        val_metric_decreasing=False,
        no_group_logging=True,
    ),
    time=dict(
        groupby_fields=['year'],
        val_metric='acc_avg',
        val_metric_decreasing=False,
    ),
    time_baseline=dict(
        groupby_fields=['year'],
        val_metric='acc_avg',
        val_metric_decreasing=False,
    ),
)

###############################
### Split-specific defaults ###
###############################

# Split-scheme-specific overrides, keyed by dataset name. Each value maps a
# split_scheme name to the options for that scheme. Datasets without an entry
# presumably rely on dataset_defaults alone — confirm against the consumer.
split_defaults = dict(
    amazon=amazon_split_defaults,
    yelp=yelp_split_defaults,
)
