# @package _global_
# Place this file in: configs/experiment/ecrt.yaml
# e-CRT baseline for conditional independence testing

defaults:
  # Compose the `data` config group (absolute group path) with this preset.
  - /data: sin_d3_000
  # `_self_` last: the keys in this file override values from the data
  # preset above. Hydra >= 1.1 inserts `_self_` here implicitly; listing it
  # makes the composition order explicit.
  - _self_

wandb:
  # NOTE(review): presumably toggles W&B logging (false keeps it on) — confirm in the entry point.
  disabled: false
  # Run group: <model_x_mode>_<task>_d<d>_coord<ca><cb><cr>_<name>_<model_a_type>_bs<samples>
  group: "${train.model_x_mode}_${wandb.task}_d${data.d}_coord${data.ca_dim_idx}${data.cb_dim_idx}${data.cr_dim_idx}_${train.name}_${train.model_a_type}_bs${data.samples}"
  # Task label; interpolated into `group` above.
  task: "sin"
  tags:
    - "icml"
    - "${train.name}"
    - "${train.model_x_mode}"
    - "random_seed"
    - "bias"
    - "corrected_model_x"

data:
  # Data-generation seed follows the training seed, so overriding
  # `train.seed` also changes the data seed.
  data_seed: ${train.seed}
  # NOTE(review): sample-count semantics are defined by the data generator —
  # confirm. Reused as `train.batch_size` and in the W&B group (`bs...`).
  samples: 20
  # NOTE(review): meaning of "type2" is defined by the data generator — confirm.
  type: "type2"

train:
  name: "ecrt"
  # Also feeds `data.data_seed` via interpolation.
  seed: 0
  # NOTE(review): presumably the number of repeated test sequences — confirm.
  seqs: 100
  # NOTE(review): semantics of T are not visible in this config — confirm.
  T: 0
  # NOTE(review): presumably the test significance level. Distinct from
  # `data.alpha`, which `noise_std` reads below.
  alpha: 0.05

  # e-CRT specific parameters
  K: 5                     # Number of tilde_a samples per test point
  model_x_mode: "model_x"  # 'online', 'pseudo_model_x', or 'model_x'
  # Std for Gaussian noise when sampling tilde_a. Interpolated from
  # `data.alpha`, which is not set in this file and is expected to come
  # from the data preset.
  noise_std: ${data.alpha}
  pretrain_samples: 3000   # Number of samples for pretraining (if not online)

  # Model architecture
  dropout: 0.1
  layer_norm: true

  # Separate hyperparameters for each model (override with tuned values).
  # Weight decays are written as 1.0e-4, not 1e-4: YAML 1.1 loaders such as
  # plain PyYAML only resolve exponent floats containing a dot, and would
  # otherwise return the string "1e-4".
  model_a_type: "mean_estimator"  # "mean_estimator" or "gmmn"
  model_a_hidden_dims: [64, 32]
  model_a_lr: 0.01                # Override with tuned value
  model_a_weight_decay: 1.0e-4    # Override with tuned value

  ecrt_model_hidden_dims: [128, 64, 32]
  ecrt_model_lr: 0.01              # Override with tuned value
  ecrt_model_weight_decay: 1.0e-4  # Override with tuned value

  # Training parameters
  epochs: 500
  # Tied to `data.samples` (presumably full-batch training — confirm).
  batch_size: ${data.samples}
  earlystopping:
    patience: 30
    delta: 0.0
