# Configuring end model training
# NOTE(review): how `prefix` and `mode` are consumed is not visible in this
# file — confirm their accepted values against the training entry point.
prefix: robust
mode: train

# Weights and Biases
# Entity, project, group and job type used for this run.
# The same `my-entity` placeholder also appears as the first element of the
# *_static_augmentation_pipelines_args entries further down in this file.
wandb_entity: my-entity # change this to your entity
wandb_project: model-patching
wandb_group: stage-2
wandb_job_type: mnist-correlation

# Logical GPUs
# NOTE(review): presumably one entry per logical GPU, with the value being a
# memory limit in MB — confirm against the device setup code.
logical_gpus:
  - 14336

# Random seed
# NOTE(review): which libraries are seeded with this value is not visible
# here — confirm in the training code.
seed: 1

# Architecture
# NOTE(review): `None` below is a plain YAML scalar, which standard loaders
# read as the string "None" (YAML null is spelled `null` or `~`) — confirm
# that the consumer expects the string form.
model_source: simple_cnn
architecture: None
pretrained: false
pretraining_source: wandb
pretraining_info: ''

# Optimizer
# SGD with a momentum coefficient of 0.9.
optimizer: SGD
momentum: 0.9

# Loss function
# The same name is also listed under `metric_names` below.
loss_name: sparse_categorical_crossentropy

# Robust hyperparameters

# IRM loss settings (penalty weight is 0.0 in this config)
irm_anneal_steps: 0
irm_penalty_weight: 0.0

# Group DRO settings
gdro_adj_coef: 1.0
gdro_lr: 0.01
gdro_mixed: true

# Augmentation consistency settings
augmentation_training: original
consistency_type: camel
consistency_penalty_weight: 5.0

# Learning rate
# NOTE(review): with the `constant` scheduler, lr_end and lr_decay_steps are
# presumably unused — confirm.
# NOTE(review): the plain scalar `None` is read as the string "None" by
# standard YAML loaders, not as null — confirm the consumer expects that.
lr_scheduler: constant
lr_start: 0.001
lr_end: None
lr_decay_steps: None

# Weight decay for regularization
weight_decay_rate: 0.0005

# Gradient clipping
# NOTE(review): -1 presumably disables global-norm clipping — confirm.
max_global_grad_norm: -1

# Training details
# NOTE(review): the difference between batch_size and baseline_batch_size is
# not visible in this file — confirm how each is used.
n_epochs: 200
batch_size: 100
baseline_batch_size: 100

# Metrics to track
metric_names:
  - 'accuracy'
  - 'sparse_categorical_crossentropy'

# Dataset splitting
# NOTE(review): which dataset validation_frac is applied to is not visible
# here — confirm in the data loading code.
cross_validation: false
validation_frac: 0.5

# Dataset settings
# The four lists below (datasets, versions, aliases, datadirs) each have four
# entries and appear to be parallel: entry i of each list describes the same
# training dataset.
train_datasets:
  - 'mnist_correlation_yz/zigzag/0/0/19800/-1/y'
  - 'mnist_correlation_yz/zigzag/0/1/200/-1/y'
  - 'mnist_correlation_yz/zigzag/1/0/200/-1/y'
  - 'mnist_correlation_yz/zigzag/1/1/19800/-1/y'
train_dataset_versions:
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
train_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# train_dataset_modifier: 'shuffle:5000,1'
# Adjust these paths to where you want the data to be stored.
train_datadirs:
  - '/home/tfdata/'
  - '/home/tfdata/'
  - '/home/tfdata/'
  - '/home/tfdata/'
# NOTE(review): -1 presumably means "no explicit shuffle buffer cap" — confirm.
max_shuffle_buffer: -1
shuffle_before_repeat: true

# Evaluation datasets: the entries here are identical to the train_* lists
# above (same dataset names, versions, aliases and data directories).
eval_datasets:
  - 'mnist_correlation_yz/zigzag/0/0/19800/-1/y'
  - 'mnist_correlation_yz/zigzag/0/1/200/-1/y'
  - 'mnist_correlation_yz/zigzag/1/0/200/-1/y'
  - 'mnist_correlation_yz/zigzag/1/1/19800/-1/y'
eval_dataset_versions:
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
  - '1.*.*'
eval_dataset_aliases:
  - '(Y=0)(Z=0)'
  - '(Y=0)(Z=1)'
  - '(Y=1)(Z=0)'
  - '(Y=1)(Z=1)'
# eval_dataset_modifier: ''
eval_datadirs:
  - '/home/tfdata/'
  - '/home/tfdata/'
  - '/home/tfdata/'
  - '/home/tfdata/'

# Static augmentations
# One pipeline list per training dataset.
train_static_augmentation_pipelines:
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
# Each args entry points to a Weights and Biases run; refer to
# PretrainedMNISTCycleGANStaticAugmentationPipeline in augment/static.py for
# the meaning of each positional argument.
# NOTE(review): the first element matches the `wandb_entity` placeholder, so
# it presumably needs to be changed together with it — confirm.
train_static_augmentation_pipelines_args:
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, true]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, true]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, true]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, true]]

# Training augmentation pipelines: each list item is a list of pipeline names,
# and the matching *_args list holds one argument list per pipeline (all
# empty here).
train_augmentation_pipelines:
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
train_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# NOTE(review): only one (empty) entry is given below although four training
# datasets are configured — presumably it is broadcast to all datasets; confirm.
train_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
train_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline

# Static augmentations for evaluation, one pipeline list per eval dataset.
eval_static_augmentation_pipelines:
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
  - [PretrainedMNISTCycleGANStaticAugmentationPipeline]
# NOTE(review): these args match the training args except for the last
# positional element (false here, true for training) — confirm this is
# intentional.
eval_static_augmentation_pipelines_args:
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, false]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, false]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, false]]
  - [['my-entity', 'mnist-correlation', 'hfyg9z4t', false, 'batchnorm', -1, false, 128, false]]
# Evaluation augmentation pipelines: each list item is a list of pipeline
# names, and the matching *_args list holds one argument list per pipeline
# (all empty here).
eval_augmentation_pipelines:
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
  - [BasicImagePreprocessingPipeline]
eval_augmentation_pipelines_args:
  - [[]]
  - [[]]
  - [[]]
  - [[]]
# NOTE(review): only one (empty) entry is given below although four eval
# datasets are configured — presumably it is broadcast to all datasets; confirm.
eval_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
eval_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline


# Dataflow settings
# NOTE(review): `disk_cached` presumably caches the input pipeline under
# cache_dir — confirm; adjust the path for your machine.
dataflow: disk_cached
cache_dir: /home/tfcache/

# Path to checkpoints, relative to the wandb.run directory
checkpoint_path: '/checkpoints/'

# Save a checkpoint every `checkpoint_freq` epochs
checkpoint_freq: 1
