defaults:
  - default
  - override causal_discovery: ges
  - override llm: gpt
  - override rag: standard
  # _self_ comes after the entries above, so values set in this file take
  # precedence over the ones they provide.
  - _self_
  # Use the project's list sweeper instead of Hydra's default sweeper; it is
  # configured through hydra.sweeper.list_params below.
  - override hydra/sweeper: list
 
hydra:
  mode: MULTIRUN
  sweep:
    dir: "outputs/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}_cem_bn"
  sweeper:
    # Standard grid search over the values below.
    grid_params:
      # Possible datasets are:
      #   asia, sachs, insurance, alarm, hailfinder, cub_causal_struct,
      #   synthetic, colormnist, colormnist_ood, celeba, celeba_unfair,
      #   siim_pneumothorax
      # Possible models are:
      #   blackbox, blackbox_multi, cbm_linear, cbm_mlp, cem, c2bm
      model: cem
      seed: 1,2,3,4,5
    # Additional list sweep.
    # NOTE(review): presumably the lists are paired element-wise (the i-th
    # entry of every key forms one run), so all lists must keep the same
    # length (currently 5) — confirm against the list sweeper.
    # NOTE(review): `dataset` is a comma-separated string while the other keys
    # are YAML lists; presumably the sweeper accepts both forms — confirm.
    list_params:
      dataset: asia, sachs, insurance, alarm, hailfinder
      model.hidden_size: [32, 32, 32, 64, 64]
      model.concept_loss_weight: [0.8, 0.8, 0.8, 0.8, 0.8]
      model.concept_hidden_size: [4, 4, 4, 8, 4]
      model.dropout: [0.5, 0.5, 0.5, 0.5, 0.5]
      engine.optim_kwargs.lr: [0.00075, 0.0004, 0.00075, 0.00075, 0.0004]

# One of: levels_true, levels_pred, nodes_true, nodes_pred, random
policy: nodes_true

dataset:
  load_embeddings: false
  load_graph: false
  load_true_graph: false
  # Batch size: 128 for Pneumo and all SCBM runs, 512 for all others.
  batch_size: 512
  autoencoder:
    # Only used for bndatasets.
    noise: 0.5
  num_workers: 0

engine:
  intervention_prob: 0.8
  # Only used for bndatasets.
  test_interv_noise: 0.8

trainer:
  # Logger to use: wandb, or null to disable logging.
  logger: null
  devices: [0]
  max_epochs: 500
  patience: 30

rag:
  verbose: false
  source: arxiv  # arxiv or custom (arxiv + www scraper)

# Free-form label for this experiment.
# NOTE(review): presumably attached to the run's logs/metadata — confirm.
notes: 'neurips_test'