# An example configuration file for training a CBM on the CheXpert dataset

# --------------------------
#         Experiment
# --------------------------
# Name of the experiment.
experiment_name: 'cbm-cheXpert'
# Name of the run.
run_name: 'test'
# Random generator seed.
seed: 42
# Device to use for PyTorch: 'cpu' or 'cuda'.
device: 'cuda'
# Number of worker processes.
workers: 2

# TODO: specify relevant directories
# Directory for saving logs and model checkpoints. The quoted '...' is a
# placeholder (same string value as before) and must be replaced with a
# real path.
log_directory: '...'
# Directory with pretrained networks, such as ResNet-18.
model_directory: './pretrained_networks/resnet/'

# NOTE(review): presumably silences library warnings during the run;
# confirm against the code that reads this flag.
suppress_warnings: true

# --------------------------
#         Dataset
# --------------------------
# Name of the dataset.
dataset: 'cheXpert'
# TODO: specify relevant directories
# Path to the dataset folder. The quoted '...' is a placeholder (same
# string value as before) and must be replaced with a real path.
data_path: '...'
# Number of classes of the target variable.
num_classes: 2
# Number of concept variables.
num_concepts: 13

# --------------------------
#         Model
# --------------------------
# Model's name.
model: 'cbm'
# Encoder architecture.
encoder_arch: 'resnet18'
# Number of hidden units.
num_hidden_y: 256
# Architecture of the prediction head.
head_arch: 'FCNN'

# Optimization method for the CBM: 'sequential' or 'joint'.
training_mode: 'joint'
# Alpha parameter for the joint loss.
alpha: 1.0

# --------------------------
#         Training
# --------------------------
# Number of training epochs in the joint optimization.
j_epochs: 350
# Learning rate in the joint optimization.
j_learning_rate: 0.0001

# Batch size for the training set.
train_batch_size: 64
# Batch size for the validation and test sets.
val_batch_size: 64

# Periodicity to evaluate the model.
validate_per_epoch: 30
# Periodicity to save model checkpoints.
pth_store_frequency: 50

# Optimizer: 'sgd' or 'adam'.
optimizer: 'adam'

# Frequency of the learning rate decrease.
decrease_every: 150
# Rate of the learning rate decrease.
lr_divisor: 2

# Weight decay.
weight_decay: 0
