# Hydra defaults list: no other config groups are pulled in, so only this
# file (plus any command-line overrides) is composed.
defaults:
  - '_self_'

# Identifiers interpolated into experiment_dir (and, through it, into the
# sweep subdir and the W&B run name).
data_type: 'dsprites'
debiasing_method: 'cdisco'

# Dataset configuration. `datamodule` follows Hydra's `_target_` convention
# (presumably built with hydra.utils.instantiate in the training script —
# confirm); the remaining keys become constructor keyword arguments.
data:
  datamodule:
    _target_: data.dSprites.dsprites_dataset.DspritesDataModule
    # HDF5 image file and the per-split CSV files (presumably index/label
    # tables for train / val / test).
    h5_path: '/datasets/dSprites/sprites.h5'
    train_csv: '/datasets/dSprites/train.csv'
    val_csv: '/datasets/dSprites/val.csv'
    test_csv: '/datasets/dSprites/test.csv'
    # Loader settings (presumably forwarded to the DataLoaders).
    num_workers: 8
    batch_size: 128

# Per-run directory name; also reused as hydra.sweep.subdir's prefix and as
# the W&B run name. It MUST be a relative path: Hydra joins hydra.sweep.dir
# and hydra.sweep.subdir with os.path.join, which discards the sweep dir when
# the subdir is absolute (a leading "/" here sent every job to /dsprites/...
# at the filesystem root instead of under results/simulation/cdisco/).
# predictor.target, lr_start and lr_end are not defined under `predictor`;
# they are appended by the `+predictor.*` entries in hydra.sweeper.params,
# so this only resolves in a multirun or with matching `+` overrides.
experiment_dir: ${data_type}/${debiasing_method}/target_${predictor.target}/in_channels_${predictor.resnet_cfg.in_channels}_n_outputs_${predictor.resnet_cfg.n_outputs}_lr_start_${predictor.lr_start}_lr_end_${predictor.lr_end}

hydra:
  job:
    # Run each job with its output directory as the working directory, so
    # relative paths (e.g. the checkpoint dirpath) land inside it.
    chdir: true
  sweep:
    # Multirun root. Each job's directory is os.path.join(dir, subdir), so
    # subdir must be a relative path.
    # NOTE(review): "cdisco" is hard-coded here instead of ${debiasing_method}.
    dir: results/simulation/cdisco/
    subdir: ${experiment_dir}/${hydra.job.id}
  sweeper:
    # Single-valued sweep entries; the `+` prefix appends keys that are not
    # present under `predictor` in this file.
    params:
      +predictor.target: label_c
      +predictor.lr_start: 0.0001
      +predictor.lr_end: 1e-08
      +predictor.resnet_cfg.n_basefilters: 64
      +predictor.bw: 0.1
      +predictor.cdcor_lambda: 0.05

# Random seed (presumably passed to the training script's seeding call).
seed: 3169

# Model configuration. target, lr_start, lr_end, bw, cdcor_lambda and
# resnet_cfg.n_basefilters are not set here; they are appended by the
# `+predictor.*` entries under hydra.sweeper.params.
predictor:
  encoder_type: resnet_gn
  pred_head: non-linear
  protected_attributes:
    - x
  resnet_cfg:
    in_channels: 1
    n_outputs: 1
    n_blocks: 3
    # NOTE(review): encoder_type is `resnet_gn`; if that means GroupNorm,
    # this batch-norm momentum may be unused — confirm in the encoder code.
    bn_momentum: 0.1
    dropout_p: 0.0
    no_pooling: true

# Keyword arguments for pytorch_lightning.Trainer (built from `_target_`).
trainer:
  _target_: pytorch_lightning.Trainer
  # Single-device GPU training.
  accelerator: gpu
  devices: 1
  max_epochs: 100
  # Logging and diagnostics.
  enable_progress_bar: true
  detect_anomaly: false
  log_every_n_steps: 1
  enable_checkpointing: true
  # Floats are fractions of the dataloader: 1.0 = every batch. Keep the
  # decimal point — an int 1 would mean "exactly one batch".
  limit_train_batches: 1.0
  limit_val_batches: 1.0

# Lightning callbacks; each entry is built from its `_target_`.
callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    # Keep the one checkpoint with the lowest validation loss, plus the
    # most recent one (save_last).
    monitor: val/loss
    mode: min
    save_top_k: 1
    # Relative path; with hydra.job.chdir enabled it resolves inside the
    # job's output directory.
    dirpath: chkpts
    # Metric names are spelled out in the template because the monitored
    # name contains "/"; auto-inserting it would put a path separator into
    # the checkpoint file name.
    filename: 'epoch-{epoch}-val_loss-{val/loss:.4f}'
    verbose: true
    save_last: true
    auto_insert_metric_name: false
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: epoch

# Weights & Biases logger; the run name mirrors experiment_dir.
logger:
  wandb:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: 'dSprites Debug'
    name: '${experiment_dir}'
    offline: false