# @package _global_
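# Place this file in: configs/experiment/ecrt.yaml
# NOTE(review): the path above says "ecrt", but this config describes the DAVT
# baseline (train.name is "davt"); confirm the intended filename, e.g. davt.yaml.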
# DAVT baseline for conditional independence testing

# Hydra defaults list: pulls the `gaussiancit` option of the `data` config
# group into this experiment. The leading "/" makes the group path absolute.
defaults:
  - /data: gaussiancit

# Weights & Biases settings. How these keys are consumed is not visible in
# this file.
wandb:
  disabled: false
  # Group name assembled from other config values through ${...} interpolation.
  # NOTE(review): the "bs" prefix is filled from data.samples, not from
  # train.batch_size; confirm that this is intended.
  group: "${train.model_x_mode}_${wandb.task}_d${data.d}_${train.name}_${train.model_a_type}_bs${data.samples}"
  # Also interpolated into `group` above.
  task: "gaussian"
  # Mix of interpolated tags (follow train.name / train.model_x_mode) and
  # fixed labels. The fixed labels are plain strings and do not track any
  # config value; update them by hand if the setup they describe changes.
  tags: ["icml", "${train.name}", "${train.model_x_mode}", "final-dp03-sb-bias", "corrected_model_x"]

# Values for the `data` node. Presumably these override what the composed
# /data config provides (depends on Hydra's `_self_` ordering); TODO confirm.
data:
  # Interpolated from train.seed, so both seeds always share one value.
  data_seed: ${train.seed}
  # NOTE(review): the two comments below mention gaussian.yaml, while the
  # defaults list selects `gaussiancit`; confirm which file these should mirror.
  samples: 20  # Match your gaussian.yaml default
  d: 19        # Match your gaussian.yaml default
  type: "type2"

# Training settings for the DAVT run. Several keys are referenced elsewhere in
# this file through interpolation: `name`, `model_x_mode` and `model_a_type`
# feed wandb.group / wandb.tags, and `seed` feeds data.data_seed.
train:
  name: "davt"
  seed: 0
  # NOTE(review): the meaning of `seqs` and `T` (and why T is 0) is not
  # documented in this file; confirm against the training code.
  seqs: 100
  T: 0
  alpha: 0.05  # presumably the test's significance level; TODO confirm

  model_x_mode: "model_x"  # 'online', 'pseudo_model_x', or 'model_x'
  noise_std: 1.0          # Std for Gaussian noise when sampling tilde_a
  pretrain_samples: 3000  # Number of samples for pretraining (if not online)

  # Model architecture
  # NOTE(review): both `dropout` (0.1) and `davt_dropout` (0.3) are set; which
  # model reads which key is not visible here.
  dropout: 0.1
  layer_norm: true
  # Separate hyperparameters for each model (override with tuned values)
  model_a_type: "mean_estimator"  # "mean_estimator" or "gmmn"
  model_a_hidden_dims: [32]
  model_a_lr: 0.01  # Override with tuned value
  model_a_weight_decay: 0.0  # Override with tuned value

  davt_model_hidden_dims: [128]
  davt_model_lr: 0.0005  # Override with tuned value
  davt_model_weight_decay: 0.0  # Override with tuned value
  davt_dropout: 0.3

  # Training parameters
  epochs: 500
  # NOTE(review): batch_size (100) is larger than data.samples (20) set in
  # this file; confirm how the two interact.
  batch_size: 100
  earlystopping:
    patience: 30
    delta: 0.0
