# Per-dataset configuration, keyed by dataset name. Each entry holds:
#   num_classes       -- number of target classes for the dataset.
#   num_envs          -- number of environments for the dataset
#                        (presumably group / attribute combinations -- confirm
#                        against the code that consumes this table).
#   hidden_layer_size -- width of the hidden feature layer (presumably the
#                        backbone's output feature dimension -- confirm).
#
# Rows are listed as (name, num_classes, num_envs, hidden_layer_size). The
# dict preserves this row order, and every inner dict uses the same key order.
datasets = {
    name: {
        'num_classes': n_classes,
        'num_envs': n_envs,
        'hidden_layer_size': hidden,
    }
    for name, n_classes, n_envs, hidden in (
        # name            classes  envs  hidden
        ('celeba',        2,       4,    2048),
        ('waterbirds',    2,       4,    2048),
        ('civilcomments', 2,       16,   768),
        ('multinli',      3,       6,    768),
        ('cmnist',        5,       25,   768),
        ('urbancars',     2,       8,    2048),
        ('domino',        2,       8,    512),
    )
}