# End-model training configuration
prefix: 'robust'
mode: 'train'

# Weights & Biases identifiers: entity, project, group and job type
wandb_entity: 'my-entity'
wandb_project: 'model-patching'
wandb_group: 'stage-2'
wandb_job_type: 'celeba'

# Logical GPUs (a single entry in this config)
# NOTE(review): 14336 is presumably a per-device memory limit; confirm the
# unit and meaning against the code that reads this key
logical_gpus:
  - 14336

# Seed for random number generation
seed: 1

# Model architecture and the source of its pretrained weights
model_source: 'cm'
architecture: 'resnet50'
pretrained: true
pretraining_source: 'wandb'
# Left as an empty string in this config
pretraining_info: ''

# Optimizer and its momentum
optimizer: 'SGD'
momentum: 0.9

# Name of the training loss
loss_name: 'sparse_categorical_crossentropy'

# Hyperparameters for the robust training objectives
# IRM loss: anneal steps and penalty weight (both zero in this config)
irm_anneal_steps: 0
irm_penalty_weight: 0.0

# Group DRO settings
gdro_adj_coef: 3.0
gdro_lr: 0.01
gdro_mixed: false

# Augmentation consistency (penalty weight is zero, type is 'none')
augmentation_training: 'original'
consistency_type: 'none'
consistency_penalty_weight: 0.0

# Learning rate: scheduler name and starting value
lr_scheduler: 'constant'
lr_start: 0.0001
# The two values below are intentionally unset. `null` is YAML's null value;
# a bare Python-style `None` would be loaded as the string 'None' instead.
# NOTE(review): confirm the consumer does not compare against the string 'None'
lr_end: null
lr_decay_steps: null

# Weight decay rate (regularization)
weight_decay_rate: 0.05

# Global gradient-norm clipping
# NOTE(review): -1 presumably disables clipping; confirm with the training code
max_global_grad_norm: -1

# Epoch count and batch sizes
n_epochs: 100
batch_size: 16
baseline_batch_size: 128

# Names of the metrics to track
metric_names:
  - 'accuracy'
  - 'sparse_categorical_crossentropy'

# Dataset splitting: cross-validation flag and validation fraction
cross_validation: false
validation_frac: 0.0

# Training datasets: four entries, each with a version pattern, an alias and a
# data directory (the lists appear to be parallel, one item per dataset)
# NOTE(review): only the first dataset name ends in '/4054'; confirm intended
train_datasets:
  - 'celeb_a_128/Blond_Hair/Male/1.0/0/0/y/4054'
  - 'celeb_a_128/Blond_Hair/Male/1.0/0/1/y'
  - 'celeb_a_128/Blond_Hair/Male/1.0/1/0/y'
  - 'celeb_a_128/Blond_Hair/Male/1.0/1/1/y'
train_dataset_versions:
  - '2.*.*'
  - '2.*.*'
  - '2.*.*'
  - '2.*.*'
train_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# Disabled option, kept for reference:
# train_dataset_modifier: 'shuffle:5000,1'
train_datadirs:
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'
# Upper bound on the shuffle buffer
max_shuffle_buffer: 5000

# Evaluation datasets: four entries, each with a version pattern, an alias and
# a data directory (the lists appear to be parallel, one item per dataset)
eval_datasets:
  - 'celeb_a_128/Blond_Hair/Male/1.0/0/0/y'
  - 'celeb_a_128/Blond_Hair/Male/1.0/0/1/y'
  - 'celeb_a_128/Blond_Hair/Male/1.0/1/0/y'
  - 'celeb_a_128/Blond_Hair/Male/1.0/1/1/y'
eval_dataset_versions:
  - '2.*.*'
  - '2.*.*'
  - '2.*.*'
  - '2.*.*'
eval_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# Disabled option, kept for reference:
# eval_dataset_modifier: ''
eval_datadirs:
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'
  - '/home/celeba_tfrecord_128'

# Static augmentations for training: four empty pipelines with empty arguments
train_static_augmentation_pipelines:
  - []
  - []
  - []
  - []
train_static_augmentation_pipelines_args:
  - []
  - []
  - []
  - []

# Training augmentation pipelines: four single-step pipelines, each step
# paired with an empty argument list
train_augmentation_pipelines:
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
train_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# GPU augmentations: one pipeline per dataset
# NOTE(review): only one entry is listed although four training datasets are
# configured; confirm the consumer accepts this
train_gpu_augmentation_pipelines:
  - []
# The corresponding arguments to pass to each augmentation in each pipeline
train_gpu_augmentation_pipelines_args:
  - []

# Static augmentations for evaluation: four empty pipelines with empty arguments
eval_static_augmentation_pipelines:
  - []
  - []
  - []
  - []
eval_static_augmentation_pipelines_args:
  - []
  - []
  - []
  - []
# Evaluation augmentation pipelines: four single-step pipelines, each step
# paired with an empty argument list
eval_augmentation_pipelines:
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
  - ['ImageNetPreprocessingPipeline']
eval_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# GPU augmentations: one pipeline per dataset
# NOTE(review): only one entry is listed although four evaluation datasets are
# configured; confirm the consumer accepts this
eval_gpu_augmentation_pipelines:
  - []
# The corresponding arguments to pass to each augmentation in each pipeline
eval_gpu_augmentation_pipelines_args:
  - []


# Dataflow mode and cache directory
dataflow: 'disk_cached'
cache_dir: '/home/tfcache/'

# Checkpoint location inside the wandb.run directory
checkpoint_path: '/checkpoints/'

# Number of epochs between checkpoints
checkpoint_freq: 1
