# Config information
# Both path keys are left null in this template.
_config_path: null
_template_config_path: null
# NOTE(review): unquoted `None` is not YAML null; it loads as the string
# "None", unlike the `null` used on the two keys above. Confirm which form the
# config loader expects before changing it.
parent_template: None

# Configuring end model training
prefix: robust
mode: train

# Weights and Biases: point these to your own credentials
# NOTE(review): `my-entity` is presumably a placeholder to be replaced with
# your own W&B entity.
wandb_entity: my-entity
wandb_project: model-patching
wandb_group: stage-2
wandb_job_type: training

# Logical GPUs: we used T4s, P100s and V100s for all our experiments.
# NOTE(review): the two entries are presumably per-logical-GPU memory limits
# (in MB); confirm against the code that reads `logical_gpus`.
logical_gpus: [4096, 10240]
dtype: float32

# Random seed: set a seed for the experiment, our final numbers are reported on
# seeds 2, 3 and 4.
# NOTE(review): the default below (1) is not one of the seeds named above;
# confirm whether the comment or the default is the intended one.
seed: 1

# Architecture
# `cm` uses a model from the classification-models repository.
model_source: cm
# All our experiments use ResNet-50 models.
architecture: resnet50
pretrained: true
pretraining_source: wandb
# Left as an empty string in this template.
pretraining_info: ''

# Optimizer: SGD with a momentum value of 0.9.
optimizer: SGD
momentum: 0.9

# Loss function (the same name also appears in `metric_names`).
loss_name: sparse_categorical_crossentropy

# Robust hyperparameters
# ------------------------
# All penalty weights and coefficients in this section are 0.0 in this
# template.

# IRM loss details
irm_anneal_steps: 100
irm_penalty_weight: 0.0

# Group DRO
gdro_adj_coef: 0.0
gdro_lr: 0.0
gdro_mixed: false

# Augmentation consistency
# Options for `augmentation_training`: 'original' | 'augmented' | 'both'
augmentation_training: 'original'
# Options for `consistency_type`:
#   'augment' | 'triplet' | 'kl' | 'reverse-kl' | 'none'
# NOTE(review): 'camel' is not among the options listed above; confirm that it
# is a value the training code accepts, or extend the list.
consistency_type: 'camel'
consistency_penalty_weight: 0.0

# Learning rate
# NOTE(review): `lr_end` and `lr_decay_steps` are the string "None", not YAML
# null (unquoted `None` is a plain string). With `lr_scheduler: constant` they
# are presumably unused; confirm before switching schedulers.
lr_scheduler: constant
lr_start: 0.001
lr_end: None
lr_decay_steps: None

# Weight decay for regularization
weight_decay_rate: 0.0005

# Gradient clipping
# NOTE(review): going by the key name, this caps the global gradient norm;
# confirm in the training loop.
max_global_grad_norm: 5.0

# Training details
n_epochs: 10
batch_size: 8
baseline_batch_size: 128

# Metrics to track
metric_names:
  - 'accuracy'
  - 'confusion_matrix'
  - 'sparse_categorical_crossentropy'

# Dataset splitting
# Cross-validation is off and the validation fraction is 0.0 in this template.
cross_validation: false
validation_frac: 0.0

# Train dataset settings
# Each of the four lists below holds a single entry in this template.
# NOTE(review): presumably the lists are index-aligned (one entry per dataset);
# confirm against the data loader.
train_datasets:
  - mnist
train_dataset_versions:
  - '3.*.*'
train_dataset_aliases:
  - mnist
train_datadirs:
  - /home/tfdata/
# train_dataset_modifier: ''
train_shuffle_seeds: null
max_shuffle_buffer: -1
shuffle_before_repeat: false

# Eval dataset settings
# Same entries as the train dataset lists in this template: one dataset
# (mnist), one version pattern, one alias and one data directory.
eval_datasets:
  - mnist
eval_dataset_versions:
  - '3.*.*'
eval_dataset_aliases:
  - mnist
eval_datadirs:
  - /home/tfdata/
# eval_dataset_modifier: ''

# Train augmentations: see examples for details
# Every key below is a list containing one empty list, so no train
# augmentations are configured in this template. Each `*_args` key pairs with
# the pipeline key directly above it.
train_static_augmentation_pipelines:
  - [] # One pipeline per dataset
train_static_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline
train_augmentation_pipelines:
  - [] # One pipeline per dataset
train_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline
train_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
train_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline

# Eval augmentations
# Every key below is a list containing one empty list, so no eval
# augmentations are configured in this template. Each `*_args` key pairs with
# the pipeline key directly above it.
eval_static_augmentation_pipelines:
  - [] # One pipeline per dataset
eval_static_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline
eval_augmentation_pipelines:
  - [] # One pipeline per dataset
eval_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline
eval_gpu_augmentation_pipelines:
  - [] # One pipeline per dataset
eval_gpu_augmentation_pipelines_args:
  - [] # The corresponding arguments to pass to each augmentation in each pipeline

# Dataflow settings: set a cache directory on your disk
dataflow: disk_cached
cache_dir: /home/tfcache/

# Path to checkpoints in wandb.run directory
checkpoint_path: '/checkpoints/'

# Checkpoint every _ epochs
# NOTE(review): 20 is larger than this template's `n_epochs: 10`; confirm
# whether any checkpoint is written when the run is shorter than the
# checkpoint interval.
checkpoint_freq: 20

# Resuming an old run from Weights and Biases
# Resuming is off in this template; the previous run id and project are null.
resume: false
prev_wandb_run_id: null
prev_wandb_project: null
prev_wandb_entity: my-entity
# Where the checkpoints are located in wandb (relative to this run's root
# path).
prev_ckpt_path: ''
prev_ckpt_epoch: -1

# Running in a Jupyter notebook
jupyter_notebook: false
