# Hydra defaults list: only this file itself (_self_) is composed.
defaults: [_self_]

# Experiment identifiers; both feed the experiment_dir interpolation.
data_type: 'dsprites'
debiasing_method: 'circe'

data:
  datamodule:
    # `_target_` follows Hydra's instantiate convention; the remaining keys
    # are presumably the DspritesDataModule constructor kwargs.
    _target_: data.dSprites.dsprites_dataset.DspritesDataModule
    # Image store plus the train/val/test split files.
    h5_path: '/datasets/dSprites/sprites.h5'
    train_csv: '/datasets/dSprites/train.csv'
    val_csv: '/datasets/dSprites/val.csv'
    test_csv: '/datasets/dSprites/test.csv'
    # Loader settings.
    num_workers: 8
    batch_size: 256
    # CIRCE held-out split. A fraction (as here) or an absolute count such
    # as 500 may be given.
    circe_enabled: true
    circe_heldout_size: 0.1

# Used as the sweep subdir (hydra.sweep.subdir) and the W&B run name
# (logger.wandb.name). predictor.target / lr_start / lr_end are only appended
# by hydra.sweeper.params (the `+predictor.*` overrides), i.e. they exist only
# under `-m`. oc.select falls back to "unset" so a plain launch no longer fails
# with an InterpolationKeyError; multirun output is unchanged.
experiment_dir: ${data_type}/${debiasing_method}/target_${oc.select:predictor.target,unset}/n_outputs_${predictor.resnet_cfg.n_outputs}_lr_start_${oc.select:predictor.lr_start,unset}_lr_end_${oc.select:predictor.lr_end,unset}

hydra:
  job:
    # Hydra switches the working directory to each job's output directory.
    chdir: true
  sweep:
    # Derive the method segment from debiasing_method instead of hard-coding
    # "circe", so overriding the method on the CLI does not misfile results.
    dir: results/simulation/${debiasing_method}/
    subdir: ${experiment_dir}/${hydra.job.id}
  sweeper:
    # Every parameter has a single value, so `-m` launches exactly one job.
    # The `+` prefix appends keys that the base `predictor` block does not
    # define; it would error if those keys were added to the base config.
    params:
      +predictor.target: label_c
      +predictor.lr_start: 0.0001
      +predictor.lr_end: 1e-08
      +predictor.resnet_cfg.n_basefilters: 64
      +predictor.circe_lambda: 50.0
      +predictor.sigma2_yhat: 0.01
      +predictor.sigma2_z: 0.01
      +predictor.sigma2_y: 0.1

# Global random seed. NOTE(review): its consumer is not visible in this file;
# presumably the training entry point seeds its RNGs with it. Confirm.
seed: 3169

# Model settings. target, lr_start, lr_end, circe_lambda, sigma2_yhat,
# sigma2_z, sigma2_y and resnet_cfg.n_basefilters are NOT defined here. They
# are appended by the `+predictor.*` entries in hydra.sweeper.params.
predictor:
  encoder_type: resnet_gn
  pred_head: non-linear
  # NOTE(review): presumably the attribute(s) the debiasing penalty targets;
  # confirm against the predictor implementation.
  protected_attributes: [x]
  resnet_cfg:
    in_channels: 1
    n_outputs: 1
    n_blocks: 3
    # NOTE(review): encoder_type is resnet_gn; confirm bn_momentum is
    # actually consumed by that encoder.
    bn_momentum: 0.1
    dropout_p: 0.0
    no_pooling: true

# Lightning Trainer kwargs (Hydra `_target_` convention). Booleans are written
# as lowercase true/false, matching the rest of the block.
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  devices: 1
  max_epochs: 100
  enable_progress_bar: true
  detect_anomaly: false
  # Log metrics on every training step.
  log_every_n_steps: 1
  enable_checkpointing: true
  # Float 1.0 = use 100% of the train/val dataloaders.
  limit_train_batches: 1.0
  limit_val_batches: 1.0

callbacks:
  # Keep the single best checkpoint by lowest val/loss, plus the last one.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val/loss
    mode: min
    save_top_k: 1
    dirpath: chkpts
    filename: "epoch-{epoch}-val_loss-{val/loss:.4f}"
    verbose: true
    save_last: true
    # The metric name contains '/'. With name auto-insertion it would end up
    # in the filename and create sub-folders, so names are spelled out in
    # `filename` instead.
    auto_insert_metric_name: false
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: 'epoch'

# W&B logger; the run name mirrors experiment_dir so runs and output
# directories line up.
logger:
  wandb:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: "dSprites Debug"
    name: "${experiment_dir}"
    offline: false
