# EEG training configuration (Hydra).
# Example run (CPU, torch.compile disabled, relative data paths):
# TORCH_COMPILE_DISABLE=1 python train.py --config-name train_eeg device=cpu data.train_path='data/eeg_dataset/train' data.val_path='data/eeg_dataset/val'
# `_self_` is listed first, so this file's values are composed before any
# other defaults entries (none are listed here).
defaults: [_self_]

# System configuration (override from the CLI, e.g. device=cpu)
device: cuda
seed: 42

# Data configuration
data:
  # Dataset locations; override per run, e.g. data.train_path='...'
  train_path: '/data/eeg_dataset/train'
  val_path: '/data/eeg_dataset/val'
  num_workers: 4
  # null because every file already holds one complete batch.
  batch_size: null
  # Only the first 100 validation batches are used.
  val_subset_size: 100

# Model architecture
model:
  dim_x: 1  # Time dimension
  dim_y: 7  # 7 EEG channels (FZ, F1, F2, F3, F4, F5, F6)
  dim_model: 128
  max_buffer_size: 16  # Fixed for EEG
  # Adjust to the data: total_points - nc - 8.
  # NOTE(review): "nc" is presumably the number of context points, and the
  # meaning of the "- 8" term is not visible here -- confirm against the
  # data pipeline before changing.
  num_target_points: 240
  targets_block_size_for_buffer_attend: 16
  q_block_size: 128
  kv_block_size: 128

  embedder:
    hidden_dim: 128
    depth: 3
    pos_emb_init: true

  backbone:
    num_layers: 6
    num_heads: 4
    dim_feedforward: 256
    dropout: 0.0

  head:
    dim_feedforward: 256
    num_components: 8
    # Written as 5.0e-3 (not 5e-3): YAML 1.1 loaders such as plain PyYAML
    # only resolve exponent notation as a float when it has a decimal
    # point, otherwise they return the string "5e-3".
    std_min: 5.0e-3
    # Use MultiChannelMixtureGaussian instead of MixtureGaussian
    type: "MultiChannelMixtureGaussian"

# Optimizer configuration
optimizer:
  name: adam
  # Written as 1.0e-4 (not 1e-4): YAML 1.1 loaders such as plain PyYAML
  # would otherwise parse the learning rate as the string "1e-4".
  lr: 1.0e-4
  betas: [0.9, 0.999]
  weight_decay: 0.0

# Learning-rate scheduler configuration
scheduler:
  use_scheduler: true
  # One of: cosine, cosine_with_warmup
  name: 'cosine_with_warmup'
  # Fraction of total steps spent warming up (0.1 = 10%).
  warmup_ratio: 0.1

# Training configuration
training:
  num_epochs: 32
  # Gradient-clipping threshold; 0.5 is a more conservative alternative
  # if training becomes unstable.
  grad_clip: 1.0
  # Compilation options (presumably forwarded to torch.compile; the header
  # example disables compilation via TORCH_COMPILE_DISABLE=1).
  compile_model: true
  compile_mask: true
  compile_mode: 'default'
  fullgraph: false
  dynamic: false
  prewarm_compilation: false
  # Mixed precision stays off for numerical stability (original note:
  # the data is float64); amp_dtype only matters if use_amp is enabled.
  use_amp: false
  amp_dtype: 'bfloat16'
  val_interval: 1

# Checkpointing
checkpoint:
  save_dir: './checkpoints/eeg'
  # NOTE(review): unit (epochs vs. steps) is not visible here -- confirm.
  save_interval: 10
  
# Logging configuration
logging:
  use_wandb: true
  project: 'fast-buffer-np'
  # Timestamped via the `now` resolver, e.g. eeg-20240101-120000.
  run_name: 'eeg-${now:%Y%m%d-%H%M%S}'
  log_interval: 50
  tags:
    - ace
    - eeg