# Dataset selection. `data.name` is also interpolated into the W&B run name
# and the Hydra sweep output directory.
data:
  name: 'simulation'
  # Presumably instantiated from `_target_` (Hydra convention); the other
  # keys become constructor arguments of SimDataModule.
  datamodule:
    _target_: 'data.sim_data.sim_dataset_3blob.SimDataModule'
    # `dataset_path` is not defined in this file; it must be supplied
    # elsewhere (another config or the command line).
    data_dir: '${dataset_path}/sim_data/sim_all.h5'
    split_dir: '${dataset_path}/sim_data/split.json'
    batch_size: 128
    num_workers: 8

# Lightning callbacks, each presumably built from its `_target_`.
callbacks:
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    # Record the learning rate once per epoch rather than per step.
    logging_interval: epoch

# Predictor (model) configuration consumed by the training code.
predictor:
  encoder_type: resnet_gn
  pred_head: non-linear
  # Field the model predicts; presumably a label key exposed by
  # SimDataModule. Confirm against the dataset.
  target: label_c
  # Attribute(s) the debiasing method treats as protected (assumed; confirm).
  protected_attributes: [nf_std]
  resnet_cfg:
    in_channels: 1
    n_outputs: 1
    n_blocks: 2
    n_basefilters: 64
    bn_momentum: 0.1
    dropout_p: 0.0
    no_pooling: true
  optimizer_cfg:
    name: radam
    # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
    # also resolve it as a float; `1e-08` loads there as the string "1e-08".
    lr_end: 1.0e-8
    weight_decay: 0.01
    # NOTE(review): there is no `lr` key here. It is appended by the
    # `+predictor.optimizer_cfg.lr` entry in hydra.sweeper.params, which is
    # only applied in multirun (-m) mode. Single runs presumably need
    # `+predictor.optimizer_cfg.lr=...` on the command line.

# Experiment logger; the run name nests under debiasing method and dataset.
logger:
  wandb:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: "DISCO Simulation 3blob"
    # `debiasing_method` and `experiment_dir` are not defined in this file;
    # presumably supplied by another config or on the command line.
    name: "${debiasing_method}/${data.name}/${experiment_dir}"
    # Sync the run to W&B online (not offline mode).
    offline: false

# Keyword arguments for pytorch_lightning.Trainer (built from `_target_`).
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  devices: 1
  max_epochs: 100
  enable_progress_bar: true
  # Autograd anomaly detection slows training; enable only when debugging.
  detect_anomaly: false
  # Log metrics on every training step.
  log_every_n_steps: 1
  enable_checkpointing: true

# Hydra runtime settings.
hydra:
  job:
    # Change the working directory into the job's output directory.
    chdir: true
  sweep:
    # Multirun outputs land in
    # results/<debiasing_method>/<data.name>/<experiment_dir>/
    dir: results/${debiasing_method}/${data.name}/
    subdir: ${experiment_dir}/
  sweeper:
    # Sweeper params are override strings. The `+` prefix is needed because
    # predictor.optimizer_cfg does not define `lr`; this only applies in
    # multirun (-m) mode.
    params:
      +predictor.optimizer_cfg.lr: "0.0005"