# Training run configuration: device, data, model, optimizer, scheduler, training loop, checkpointing, and logging

# Device selection; accepted values noted here are `cuda` or `cpu`.
device: cuda

# Dataset locations and batching settings
data:
  # Relative paths; the base directory they resolve against is decided by the consumer.
  train_path: data/train_gp_data
  val_path: data/val_gp_data
  num_workers: 0
  batch_size: 512
  # NOTE(review): the meaning of 0 is not defined in this file
  # (possibly "use the full validation set") — confirm against the loader.
  val_subset_size: 0

# Model architecture settings
model:
  # Top-level dimensions and sizes
  dim_x: 1
  dim_y: 1
  dim_model: 128
  max_buffer_size: 8
  num_target_points: 64
  targets_block_size_for_buffer_attend: 8
  # Block sizes for the q and kv sides (set to the same value here)
  q_block_size: 128
  kv_block_size: 128

  # Embedder sub-section
  embedder:
    hidden_dim: 256
    depth: 3

  # Backbone sub-section
  backbone:
    num_layers: 6
    num_heads: 4
    dim_feedforward: 256
    dropout: 0.0

  # Head sub-section
  head:
    dim_feedforward: 256
    # NOTE(review): presumably the number of mixture components
    # produced by the head — confirm against the model code.
    num_components: 20

# Optimizer configuration
optimizer:
  name: adamw
  # Written with an explicit mantissa dot on purpose: YAML 1.1 loaders
  # (e.g. PyYAML) resolve a bare `1e-4` as the string "1e-4", not a float.
  # `1.0e-4` is the same value and is a float under every common loader.
  lr: 1.0e-4
  betas: [0.9, 0.999]
  weight_decay: 0.1

# Scheduler configuration (learning-rate schedule settings)
scheduler:
  # Toggle for the scheduler settings below.
  use_scheduler: true
  name: cosine_with_warmup
  # NOTE(review): presumably the fraction of total training steps spent in
  # warmup (0.1 = 10%) — confirm against the scheduler implementation.
  warmup_ratio: 0.1

# Training loop configuration
training:
  num_epochs: 50
  grad_clip: 1.0

  # Compilation settings
  compile_model: true
  # Compile the mask creation function
  compile_mask: true
  compile_mode: default
  fullgraph: false
  dynamic: false
  # Pre-compile all 8 shape variants
  prewarm_compilation: true

  # Mixed-precision settings
  use_amp: true
  amp_dtype: bfloat16

  # NOTE(review): unit is not stated here (presumably epochs between
  # validation runs) — confirm against the training loop.
  val_interval: 1

# Checkpoint output settings
checkpoint:
  # Date/time-stamped directory. The `${now:...}` interpolations are expanded
  # by the config loader — assumes a `now` resolver is registered; confirm.
  save_dir: 'checkpoints/${now:%Y-%m-%d}/${now:%H-%M-%S}'
  # Save every N epochs
  save_interval: 10

# Experiment logging settings
logging:
  use_wandb: true
  project: ace-training
  # Timestamped run name; `${now:...}` is expanded by the config loader
  # (assumes a `now` resolver is registered — confirm).
  run_name: 'ace-${now:%Y%m%d-%H%M%S}'
  # Log every N steps
  log_interval: 50
  tags:
    - ace
    - gp-data