# Hydra defaults list: this config composes only its own contents.
defaults:
  - _self_

# Run identifiers; both are interpolated into `experiment_dir` below and thus
# into the sweep sub-directory and the W&B run name.
data_type: waterbirds
debiasing_method: hscic

data:
  datamodule:
    # Presumably built with hydra.utils.instantiate: `_target_` names the
    # class and the remaining keys become its constructor kwargs — confirm.
    _target_: data.waterbirds.waterbirds.WaterbirdsDataModule
    # Dataset root can be relocated by exporting WATERBIRDS_ROOT. When it is
    # unset, the original /datasets/waterbirds_bias location is used, so
    # existing setups resolve to exactly the same paths as before.
    data_dir: ${oc.env:WATERBIRDS_ROOT,'/datasets/waterbirds_bias'}/data
    train_csv: ${oc.env:WATERBIRDS_ROOT,'/datasets/waterbirds_bias'}/train.csv
    val_csv: ${oc.env:WATERBIRDS_ROOT,'/datasets/waterbirds_bias'}/val.csv
    test_csv: ${oc.env:WATERBIRDS_ROOT,'/datasets/waterbirds_bias'}/test.csv
    num_workers: 8
    batch_size: 64

# Per-run relative path built from the dataset, debiasing method and
# predictor hyperparameters; used as the sweep sub-directory and as the
# W&B run name.
# NOTE(review): predictor.target, predictor.lr_start and predictor.lr_end are
# not defined in the `predictor` block; they are only added by the
# `+predictor.*` sweeper params, so resolving this in a single (non -m) run
# presumably fails unless those keys are supplied as overrides — confirm.
experiment_dir: ${data_type}/${debiasing_method}/target_${predictor.target}/n_outputs_${predictor.resnet_cfg.n_outputs}_lr_start_${predictor.lr_start}_lr_end_${predictor.lr_end}

hydra:
  job:
    # Each job runs with its Hydra output directory as the working directory,
    # so relative paths in this file (e.g. the checkpoint `dirpath`) land there.
    chdir: true
  sweep:
    # Multirun root. Interpolates ${debiasing_method} rather than hard-coding
    # "hscic", so the directory follows the method if it is overridden
    # (the resolved value is unchanged for this config).
    # NOTE(review): "simulation" looks inherited from a simulation config even
    # though data_type is waterbirds, and experiment_dir repeats
    # ${data_type}/${debiasing_method} beneath this root — confirm the
    # intended layout before changing it, since existing results live here.
    dir: results/simulation/${debiasing_method}/
    subdir: ${experiment_dir}/${hydra.job.id}
  sweeper:
    # Values applied by the sweeper in multirun (-m) mode. The `+` prefix
    # appends keys that the `predictor` block does not define, so a single
    # run presumably lacks them unless they are passed on the command line.
    params:
      +predictor.target: label
      +predictor.lr_start: 0.001
      +predictor.lr_end: 1e-08
      +predictor.sigma2_yhat: 0.1
      +predictor.sigma2_z: 0.1
      +predictor.sigma2_y: 0.1
      +predictor.hscic_lambda: 10.0
      +predictor.ridge_lambda: 0.01

# Global seed; presumably consumed by the training script's seeding call —
# confirm against the entry point.
seed: 3169

# Model configuration. Additional predictor keys (target, lr_start, lr_end,
# sigma2_*, hscic_lambda, ridge_lambda) are appended by the Hydra sweeper
# params rather than declared here.
predictor:
  encoder_type: resnet_18
  pred_head: non-linear
  # Presumably the attribute(s) the debiasing penalty conditions on — confirm.
  protected_attributes:
    - b
  resnet_cfg:
    n_outputs: 1

trainer:
  # Lightning Trainer; the remaining keys are its constructor arguments.
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  devices: 1
  max_epochs: 100
  enable_progress_bar: true
  detect_anomaly: false
  # Log metrics on every optimisation step.
  log_every_n_steps: 1
  enable_checkpointing: true
  # Keep these as floats: 1.0 means "the whole dataloader", whereas an int 1
  # would limit each epoch to a single batch.
  limit_train_batches: 1.0
  limit_val_batches: 1.0

callbacks:
  model_checkpoint:
    # Keep the single best checkpoint by minimum val/loss, plus the last one.
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val/loss
    mode: min
    save_top_k: 1
    # Relative path; resolved inside the per-job directory because
    # hydra.job.chdir is enabled.
    dirpath: chkpts
    filename: "epoch-{epoch}-val_loss-{val/loss:.4f}"
    verbose: true
    save_last: true
    # The metric name contains "/", which would otherwise be inserted into the
    # filename and create extra sub-directories; the template above spells
    # out its own labels instead.
    auto_insert_metric_name: false
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: 'epoch'

logger:
  wandb:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: "Waterbirds"
    # Run name mirrors the experiment's output sub-path.
    name: "${experiment_dir}"
    offline: false
