# Configuration for end-model training.
# NOTE(review): the meanings of `prefix` and `mode` are defined by the code
# that loads this file, which is not visible here — confirm accepted values.
prefix: robust
mode: train

# Weights and Biases (W&B) settings: entity, project, group and job type.
wandb_entity: my-entity
wandb_project: model-patching
wandb_group: stage-2
wandb_job_type: waterbirds

# Logical GPUs
# NOTE(review): a one-element list; 14336 is presumably a per-logical-GPU
# memory limit (MB?) rather than a device index — confirm against the loader.
logical_gpus: [14336]

# Random seed
seed: 1

# Architecture
# NOTE(review): `cm` is an abbreviation that is not explained in this file.
model_source: cm
architecture: resnet50
pretrained: True
pretraining_source: wandb
# Left as an empty string in this config.
# NOTE(review): with `pretrained: True` and a wandb source, this presumably
# should identify the pretrained run/checkpoint — confirm empty is acceptable.
pretraining_info: ''

# Optimizer
optimizer: SGD
momentum: 0.9

# Loss function
loss_name: sparse_categorical_crossentropy

# Robust hyperparameters
# IRM loss details
# NOTE(review): both IRM values are zero, which presumably disables the IRM
# penalty for this run — confirm in the training code.
irm_anneal_steps: 0
irm_penalty_weight: 0.0

# Group DRO
# NOTE(review): the exact role of each gdro_* value is defined by the training
# code, which is not visible here.
gdro_adj_coef: 2.0
gdro_lr: 0.01
gdro_mixed: True

# Augmentation consistency
# The consistency term is configured with type 'camel' and a weight of 100.0.
augmentation_training: 'original'
consistency_type: 'camel'
consistency_penalty_weight: 100.0

# Learning rate
lr_scheduler: constant
lr_start: 0.0001
# The bare word `None` is not a YAML null (null is spelled `null` or `~`), so
# the two values below load as the string "None".
# NOTE(review): confirm the consumer handles that string, or that these keys
# are ignored when `lr_scheduler` is `constant`.
lr_end: None
lr_decay_steps: None

# Weight decay for regularization
weight_decay_rate: 0.001

# Gradient clipping
# NOTE(review): -1 presumably means "no clipping" — confirm.
max_global_grad_norm: -1

# Training details
n_epochs: 500
batch_size: 16
# NOTE(review): how `baseline_batch_size` relates to `batch_size` is not
# visible in this file — confirm in the training code.
baseline_batch_size: 128

# Metrics to track (one list entry per metric name)
metric_names:
  - 'accuracy'
  - 'sparse_categorical_crossentropy'

# Dataset splitting
cross_validation: False
validation_frac: 0.0

# Dataset settings
# Each of the four train_* lists below has four entries, in the same order.
# NOTE(review): they appear to be parallel lists (entry i of each describes the
# same dataset) — confirm against the loader.
train_datasets:
  - 'waterbirds/0/0/y'
  - 'waterbirds/0/1/y'
  - 'waterbirds/1/0/y'
  - 'waterbirds/1/1/y'
# Version pattern for each dataset; all four use the same wildcard pattern.
train_dataset_versions:
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
# Aliases follow the same 0/0, 0/1, 1/0, 1/1 order as the dataset names above.
train_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# All four datasets are read from the same directory.
train_datadirs:
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224
# Optional shuffle settings, currently commented out (not set by this file).
# max_shuffle_buffer: -1
# shuffle_before_repeat: False

# Evaluation datasets. Each of the four eval_* lists below has four entries.
# NOTE(review): they appear to be parallel lists (entry i of each describes the
# same dataset) — confirm against the loader.
eval_datasets:
  - 'waterbirds/0/0/y'
  - 'waterbirds/0/1/y'
  - 'waterbirds/1/0/y'
  - 'waterbirds/1/1/y'
# Version pattern for each dataset; all four use the same wildcard pattern.
eval_dataset_versions:
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
# Aliases follow the same 0/0, 0/1, 1/0, 1/1 order as the dataset names above.
eval_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# All four datasets are read from the same directory.
eval_datadirs:
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224
  - /home/waterbirds_tfrecord_224


# Static augmentations
# Each outer list entry is itself a list of pipeline names; every entry here
# holds a single pipeline.
train_static_augmentation_pipelines:
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
# Positional arguments for the pipelines above. These point to a Weights and
# Biases run: refer to PretrainedCycleGANStaticAugmentationTFRecordPipeline in
# augment/static.py for the meaning of each position.
# The first two rows use run id '5f2gmy7w'; the last two use 'vla2y0m7'.
# The final argument is True in every row here.
train_static_augmentation_pipelines_args:
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', '5f2gmy7w', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, True]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', '5f2gmy7w', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, True]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', 'vla2y0m7', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, True]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', 'vla2y0m7', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, True]]
# One single-pipeline entry per row, each with an empty argument list below.
train_augmentation_pipelines:
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
train_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# NOTE(review): only one (empty) pipeline is listed in each of the two GPU
# lists below, while `train_datasets` has four entries and the inline comment
# says "one pipeline per dataset" — confirm the loader accepts this.
train_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
train_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline

# Evaluation-time static augmentations.
# Each outer list entry is itself a list of pipeline names; every entry here
# holds a single pipeline.
eval_static_augmentation_pipelines:
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
  - [PretrainedCycleGANStaticAugmentationTFRecordPipeline]
# Positional arguments for the pipelines above.
# The first two rows use run id '5f2gmy7w'; the last two use 'vla2y0m7'.
# The final argument is False in every row here.
# NOTE(review): the meaning of each position is defined by the pipeline class,
# which is not visible in this file — confirm there.
eval_static_augmentation_pipelines_args:
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', '5f2gmy7w', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, False]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', '5f2gmy7w', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, False]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', 'vla2y0m7', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, False]]
  - [['/home/tfdata/augmented-datasets/', 'my-entity', 'waterbirds', 'vla2y0m7', False, [224, 224, 3], 'batchnorm', -1, False, '', 4, False]]
# One single-pipeline entry per row, each with an empty argument list below.
eval_augmentation_pipelines:
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
  - [ImageNetPreprocessingPipeline]
eval_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# NOTE(review): only one (empty) pipeline is listed in each of the two GPU
# lists below, while `eval_datasets` has four entries and the inline comment
# says "one pipeline per dataset" — confirm the loader accepts this.
eval_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
eval_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline

# Dataflow settings
# NOTE(review): `disk_cached` presumably caches the input pipeline under
# `cache_dir` — confirm in the dataflow code.
dataflow: disk_cached
cache_dir: /home/tfcache/

# Path to checkpoints in wandb.run directory
# NOTE(review): the value starts with a slash; presumably it is appended to the
# wandb.run directory rather than used as an absolute path — confirm.
checkpoint_path: '/checkpoints/'

# Checkpoint frequency, in epochs (1 = every epoch)
checkpoint_freq: 1
