# Dataset selection and the LightningDataModule that serves it.
data:
  name: dsprites
  datamodule:
    # Class path consumed via Hydra's `_target_` convention; the sibling keys
    # are presumably passed as constructor kwargs — confirm at the call site.
    _target_: 'data.dSprites.dsprites_dataset_multi.DspritesDataModule'
    # `dataset_path` is not defined in this file; it must be supplied
    # elsewhere (e.g. a parent config or a command-line override).
    h5_path: '${dataset_path}/sprites.h5'
    train_csv: '${dataset_path}/train.csv'
    val_csv: '${dataset_path}/val.csv'
    test_csv: '${dataset_path}/test.csv'
    num_workers: 8
    batch_size: 256

# Lightning callbacks; each entry carries a `_target_` so it can be
# instantiated by Hydra (presumably by the training script — confirm).
callbacks:
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    # Record the learning rate once per epoch rather than per step.
    logging_interval: epoch

# Model and optimizer settings for the predictor network.
predictor:
  encoder_type: resnet_18
  pred_head: non-linear
  # Name of the supervised target; presumably a column in the CSV files
  # referenced under `data.datamodule` — confirm with the data module.
  target: label
  # List of protected attribute names (currently just "x"); presumably a
  # dSprites latent factor — confirm with the consuming code.
  protected_attributes: [x]
  resnet_cfg:
    # Number of outputs produced by the network head.
    n_outputs: 1
    # Assumed to disable loading pretrained encoder weights — confirm in the
    # encoder factory.
    pretrained: false
  optimizer_cfg:
    name: radam
    # Written with a decimal point so plain YAML 1.1 loaders (e.g. PyYAML's
    # default resolver) parse it as a float; `1e-08` would be read there as
    # the string "1e-08". OmegaConf accepts both forms.
    # NOTE(review): presumably the final LR of a schedule — confirm.
    lr_end: 1.0e-08
    weight_decay: 0.01
    # `lr` is intentionally absent here; it is appended via the
    # `+predictor.optimizer_cfg.lr` sweeper parameter under `hydra:`.

# Experiment tracking via Weights & Biases.
logger:
  wandb:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: "DISCO dSprites Multi Single Shape"
    # `debiasing_method` and `experiment_dir` are not defined in this file;
    # they must be supplied elsewhere (parent config or command line).
    name: "${debiasing_method}/${data.name}/${experiment_dir}"
    # Canonical lowercase boolean: sync runs online rather than offline.
    offline: false

# Keyword arguments for pytorch_lightning.Trainer (built from `_target_`).
trainer:
  _target_: pytorch_lightning.Trainer
  # Train on a single GPU.
  accelerator: gpu
  devices: 1
  max_epochs: 100
  enable_progress_bar: true
  # Autograd anomaly detection slows training; enable only when debugging.
  detect_anomaly: false
  # Emit logged metrics on every training step.
  log_every_n_steps: 1
  enable_checkpointing: true

# Hydra runtime settings.
hydra:
  job:
    # Change the working directory to the job's output directory; relative
    # paths used at runtime then resolve against that directory.
    chdir: true
  # Output layout for multirun (`-m`) sweeps.
  sweep:
    dir: results_multi_single_shape/${debiasing_method}/${data.name}/
    subdir: ${experiment_dir}/
  sweeper:
    params:
      # `+` appends a key absent from the base config (`optimizer_cfg` above
      # has no `lr`). Only one value is listed, so this adds no extra sweep
      # dimension.
      +predictor.optimizer_cfg.lr: 0.0001
