# Model
# Directory of pretrained models.
model_directory: './pretrained_networks/'
# Classifier head type; one of ['linear', 'nonlinear'].
head_arch: 'linear'
# Weight of the concept loss in joint training.
alpha: 1.0
# Encoder architecture; one of ['resnet18', 'simple_CNN', 'FCNN'].
encoder_arch: 'resnet18'

# Training
# Optimization method; one of ['joint', 'sequential', 'independent'].
training_mode: 'joint'
# How often the model is evaluated
# (presumably counted in epochs -- TODO confirm against the training loop).
validate_per_epoch: 30
# Learning rate used in the optimization.
learning_rate: 0.0001
# Optimizer; one of ['sgd', 'adam'].
optimizer: 'adam'
# Interval between learning-rate decreases
# (presumably counted in epochs -- TODO confirm).
decrease_every: 150
# Amount of each learning-rate decrease
# (name suggests the rate is divided by this value -- TODO confirm).
lr_divisor: 2
# Weight decay.
weight_decay: 0
# Batch size for the training set.
train_batch_size: 256
# Batch size for the validation and test sets.
val_batch_size: 256

# Whether to use torch.compile to optimize the model. Requires Triton.
compile: false

# Number of epochs for joint training.
j_epochs: 300
# Number of epochs for the first stage of sequential & independent training.
c_epochs: 200
# Number of epochs for the second stage of sequential & independent training.
t_epochs: 100

# MCMC
# Number of MC samples when drawing from the concept distribution.
num_monte_carlo: 100
# If Gumbel-Softmax is used, whether to do straight-through or not;
# one of [true, false].
straight_through: true