"""
Hyperparameters in this config file have been set according to the ERM runs reported in the paper.
"""

# Mapping from dataset name to its default hyperparameters.
# Each entry keeps its keys in a fixed order, so iterating a config yields a stable order.
# In every entry, dann_classifier_lr is exactly 10x dann_featurizer_lr.
dataset_defaults = dict(
    # Camelyon17: densenet121 with pretrained=False, SGD with momentum 0.9, 96x96 inputs.
    camelyon17=dict(
        split_scheme='official',
        model='densenet121',
        model_kwargs=dict(pretrained=False),
        train_base_transforms='image_base',
        eval_base_transforms='image_base',
        target_resolution=(96, 96),
        loss_function='cross_entropy',
        groupby_fields=['hospital'],
        val_split='val',
        val_metric='acc_avg',
        val_metric_decreasing=False,
        batch_size=128,
        gradient_accumulation_steps=1,
        lr=0.003898422290069297,  # fixed across all runs
        weight_decay=0.01,
        n_epochs=10,
        n_groups_per_batch=2,
        # Algorithm-specific knobs (IRM / CORAL / DANN).
        irm_lambda=10,
        coral_penalty_weight=1,
        dann_penalty_weight=0.01266752946816601,
        dann_featurizer_lr=0.0030693212138627936,
        dann_classifier_lr=0.030693212138627936,
        dann_discriminator_lr=0.001975009770387647,
        algo_log_metric='accuracy',
        process_outputs_function='multiclass_logits_to_pred',
        optimizer='SGD',
        optimizer_kwargs=dict(momentum=0.9),
        scheduler=None,
    ),
    # iWildCam: pretrained resnet50, Adam, 448x448 inputs, validated on F1-macro_all.
    iwildcam=dict(
        loss_function='cross_entropy',
        model='resnet50',
        train_base_transforms='image_base',
        eval_base_transforms='image_base',
        val_metric_decreasing=False,
        val_split='val',
        val_metric='F1-macro_all',
        algo_log_metric='accuracy',
        lr=1.2352813497608926e-05,
        weight_decay=0.0,
        n_epochs=15,
        batch_size=24,
        gradient_accumulation_steps=1,
        split_scheme='official',
        groupby_fields=['location'],
        n_groups_per_batch=2,
        # Algorithm-specific knobs (IRM / CORAL / DANN).
        irm_lambda=1.0,
        coral_penalty_weight=1.0,
        dann_penalty_weight=0.5454949490827471,
        dann_featurizer_lr=3.274083574285868e-06,
        dann_classifier_lr=3.274083574285868e-05,
        dann_discriminator_lr=9.364480761173232e-05,
        no_group_logging=True,
        process_outputs_function='multiclass_logits_to_pred',
        model_kwargs=dict(pretrained=True),
        target_resolution=(448, 448),
        optimizer='Adam',
        scheduler=None,
    ),
    # BirdCalls: pretrained efficientnet-b0, Adam, 224x224 inputs, validated on id_val.
    birdcalls=dict(
        split_scheme='combined',
        model='efficientnet-b0',
        train_base_transforms='image_base',
        eval_base_transforms='image_base',
        loss_function='cross_entropy',
        groupby_fields=['location'],
        n_groups_per_batch=2,
        optimizer='Adam',
        scheduler=None,
        val_split='id_val',
        val_metric='acc_avg',
        val_metric_decreasing=False,
        batch_size=64,
        gradient_accumulation_steps=1,
        lr=0.0007879391555285274,
        weight_decay=0.001,
        n_epochs=100,
        algo_log_metric='accuracy',
        process_outputs_function='multiclass_logits_to_pred',
        # Algorithm-specific knobs (IRM / CORAL / DANN).
        irm_lambda=1.0,
        coral_penalty_weight=0.1,
        dann_penalty_weight=2.984284997073748,
        dann_featurizer_lr=0.0005208790981260641,
        dann_classifier_lr=0.005208790981260641,
        dann_discriminator_lr=0.00011898360266164227,
        model_kwargs=dict(pretrained=True),
        target_resolution=(224, 224),
    ),
)