# @package _global_
# Place this file in: configs/experiment/ecrt.yaml
# e-CRT baseline for conditional independence testing

defaults:
  # Compose the car-insurance dataset config from the /data config group.
  - /data: carinsurance
  # Listed last so the values in this file override the /data defaults above
  # (matches Hydra's implicit placement of _self_ when it is omitted).
  - _self_

wandb:
  disabled: false
  # Group label composed via OmegaConf interpolation from the train, wandb and
  # data settings. The "bs" suffix is data.samples, which train.batch_size also
  # interpolates.
  group: "${train.model_x_mode}_${wandb.task}_${data.state}_${data.n_vals}_${train.name}_${train.model_a_type}_bs${data.samples}"
  task: "carinsurance"
  tags: ["icml", "${train.name}", "random_seed", "bias"]

data:
  # Tied to train.seed: changing train.seed also changes the data seed.
  data_seed: ${train.seed}
  # Also interpolated as train.batch_size and into the wandb group label.
  samples: 20
  type: "type2"
  state: "ca"  # one of: ca, il, mo, tx
  company_idx: 0  # 0-based index; ca:[0-20], il:[0-33], mo:[0-24], tx:[0-17]
  n_vals: 20

train:
  name: "ecrt"
  seed: 0  # also used as data.data_seed via interpolation
  seqs: 100
  T: 0
  alpha: 0.05

  # e-CRT specific parameters
  K: 5                    # Number of tilde_a samples per test point
  model_x_mode: "online"  # 'online', 'pseudo_model_x', or 'model_x'
  noise_std: 1.0          # Standard deviation of the Gaussian noise used when sampling tilde_a
  pretrain_samples: 3000  # Number of samples for pretraining (if not online)

  # Model architecture
  dropout: 0.1
  layer_norm: true

  # Separate hyperparameters for each model (override with tuned values)
  model_a_type: "mean_estimator"  # "mean_estimator" or "gmmn"
  model_a_hidden_dims: [128, 64]
  model_a_lr: 0.01  # Override with tuned value
  model_a_weight_decay: 0.0  # Override with tuned value

  ecrt_model_hidden_dims: [128]
  ecrt_model_lr: 0.01  # Override with tuned value
  # Written with a decimal point: a bare "1e-4" is a string, not a float,
  # under YAML 1.1 loaders such as PyYAML's SafeLoader.
  ecrt_model_weight_decay: 1.0e-4  # Override with tuned value

  # Training parameters
  epochs: 500
  batch_size: ${data.samples}  # full-batch: one batch holds all data.samples
  earlystopping:
    patience: 30
    delta: 0.0
