# An example configuration file for training a black-box classifier on the synthetic data

# --------------------------
#         Experiment
# --------------------------
experiment_name: 'black-box-synthetic-default'            # name of the experiment
run_name: 'test'                                          # name of the run
seed: 42                                                  # random generator seed
device: 'cuda'                                            # device to use for PyTorch: 'cpu' or 'cuda'
workers: 2                                                # number of worker processes
# TODO: replace the '...' placeholders below with the actual directories before running.
# The placeholders are quoted so they are not mistaken for the YAML document-end marker.
log_directory: '...'                                      # directory for saving logs and model checkpoints
model_directory: '...'                                    # directory with pretrained networks, such as ResNet-18

suppress_warnings: true                                   # boolean flag; presumably silences warnings at runtime (confirm in the loader)

# --------------------------
#         Dataset
# --------------------------
dataset: 'synthetic'                                      # name of the dataset
sim_type: 'default'                                       # data-generating mechanism used for the synthetic data
num_classes: 2                                            # number of classes of the target variable
num_concepts: 30                                          # number of concept variables
num_covariates: 1500                                      # number of covariates
num_points: 50000                                         # number of data points

# --------------------------
#         Model
# --------------------------
model: 'black-box'                                        # model's name
encoder_arch: 'FCNN'                                      # encoder architecture
num_hidden_z: 256                                         # number of hidden units (presumably for the z representation; confirm)
num_hidden_y: 256                                         # number of hidden units (presumably for the y prediction head; confirm)
act_z: 'ReLU'                                             # activation in the intermediate layers
head_arch: 'FCNN'                                         # architecture of the prediction head

concatenation: false                                      # whether to fine-tune the black box by concatenating concepts and zs

# NOTE(review): the comment below refers to a CBM although model is 'black-box';
# confirm whether this key is actually read for the black-box model.
training_mode: 'joint'                                    # optimization method for the CBM: 'sequential' or 'joint'

# --------------------------
#         Training
# --------------------------
num_epochs: 100                                           # number of training epochs in the joint optimization
j_learning_rate: 0.0001                                   # learning rate in the joint optimization

train_batch_size: 64                                      # batch size for the training set
val_batch_size: 64                                        # batch size for the validation and test sets

# The units of the two periods below (epochs vs. iterations) are not visible in this file; TODO confirm.
validate_per_epoch: 30                                    # how often the model is evaluated
pth_store_frequency: 2000                                 # how often model checkpoints are saved

optimizer: 'adam'                                         # optimizer: 'sgd' or 'adam'

# NOTE(review): decrease_every (150) exceeds num_epochs (100); if both are counted in epochs,
# the learning rate would never be decreased during this run. Confirm the unit in the training code.
decrease_every: 150                                       # period of the learning rate decrease
lr_divisor: 2                                             # rate of the learning rate decrease (presumably the divisor applied each time; confirm)

weight_decay: 0                                           # weight decay