---
device: cuda  # compute device: cuda | cpu

data:
  # ---- Online task generation ----
  batch_size: 512               # tasks per batch
  num_batches_per_epoch: 2000   # training iterations per epoch
  val_num_batches: 32           # online validation batches per epoch
  # Dummy value; only present so the val loader is built when val_num_batches > 0.
  val_path: '.'
  num_workers: 2

  # ---- SCM prior + normalization (same semantics as the offline generator) ----
  d_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]       # max value equals model.dim_x
  nc_list: [8, 16, 32, 64, 128, 256, 512, 1024]  # one precompile shape per entry
  num_buffer: 32                # equals model.max_buffer_size
  num_target: 512               # equals model.num_target_points
  normalize_x: true
  x_norm_method: power
  x_outlier_threshold: 4.0
  normalize_y: true
  dtype: float32
  seed: 123

model:
  # Tabular ACE model with a TabularEmbedder front end.
  dim_x: 10                                 # = max(data.d_list) = embedder.max_dim_x
  dim_y: 1
  dim_model: 64
  ff_factor: 2
  max_buffer_size: 32                       # = data.num_buffer
  num_target_points: 512                    # = data.num_target
  targets_block_size_for_buffer_attend: 32
  q_block_size: 128
  kv_block_size: 128
  attending_chunks: 16

  embedder:
    type: tabular
    concat_cls: true
    max_dim_x: 10
    # Column-wise (TabICL) encoder.
    # The names in the trailing comments are the corresponding TabICL-style keys.
    num_isab_blocks: 3          # -> col_num_blocks
    col_nhead: 4                # -> col_nhead
    num_inducing_points: 128    # -> col_num_inds
    # Row-wise encoder with RoPE.
    num_layers: 3               # -> row_num_blocks
    row_nhead: 8                # -> row_nhead
    num_cls_tokens: 4           # -> row_num_cls
    row_rope_base: 100000

  backbone:
    num_layers: 12    # TabICL-style key: icl_num_blocks
    num_heads: 4
    dropout: 0.0      # dropout disabled

  head:
    type: MixtureGaussian
    num_components: 20
    # Written as 1.0e-3 rather than 1e-3: YAML 1.1 loaders such as plain PyYAML
    # only resolve floats that contain a '.', so 1e-3 would load as a string.
    std_min: 1.0e-3

  # Pre-build attention masks for every shape sampled online so none are built
  # at runtime. One entry per data.nc_list value, as [Nc + Nb, Nt]: context size
  # Nc plus buffer size Nb (= data.num_buffer = 32), then the target count
  # (presumably Nt = data.num_target = 512).
  precompile_shapes:
    - [40, 512]     # Nc=8   + Nb=32
    - [48, 512]     # Nc=16  + Nb=32
    - [64, 512]     # Nc=32  + Nb=32
    - [96, 512]     # Nc=64  + Nb=32
    - [160, 512]    # Nc=128 + Nb=32
    - [288, 512]    # Nc=256 + Nb=32
    - [544, 512]    # Nc=512 + Nb=32
    - [1056, 512]   # Nc=1024 + Nb=32

  include_diagonal_mask: false

optimizer:
  name: adamw
  # Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders such as plain PyYAML
  # only resolve floats that contain a '.', so 1e-4 would load as a string.
  lr: 1.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.0

scheduler:
  use_scheduler: true
  name: cosine_with_warmup
  # warmup_steps takes precedence; warmup_ratio is kept only for readability,
  # so it must agree with it: 2000 / 20000 (training.max_steps) = 0.10.
  warmup_ratio: 0.10
  warmup_steps: 2000
  # NOTE(review): under the HF-style cosine formula, num_cycles: 1 is a full
  # cosine wave (LR reaches 0 mid-run, then climbs back to the peak); 0.5 gives
  # a single decay. Confirm against the scheduler implementation.
  num_cycles: 1

training:
  num_epochs: 50
  grad_clip: 0.5
  # Compilation
  compile_model: true
  compile_mask: true
  compile_mode: default
  fullgraph: false
  dynamic: false
  prewarm_compilation: true
  # Mixed precision
  use_amp: false
  amp_dtype: bfloat16
  # Validation / run length
  val_interval: 1
  # Optional global step cap; when set (>0) it overrides the epoch length.
  # NOTE(review): 20000 / data.num_batches_per_epoch (2000) = 10 epochs, so
  # num_epochs: 50 is presumably never reached; confirm intent.
  max_steps: 20000
  # Step-based validation: run validation every N steps.
  val_step_interval: 250
  val_step_interval: 250  # run validation every N steps (step-based)

checkpoint:
  # Timestamped per-run directory via ${now:...} interpolation (Hydra-style resolver).
  save_dir: 'checkpoints/tabular_online_${now:%Y-%m-%d}/${now:%H-%M-%S}'
  # NOTE(review): unit (epochs vs steps) not visible here; confirm against trainer.
  save_interval: 10

logging:
  use_wandb: true
  project: ace-tabular
  run_name: 'tabular-online-${now:%Y%m%d-%H%M%S}'
  log_interval: 50
  # NOTE(review): the "parquet" tag looks stale for an online-generation run; confirm.
  tags: ['ace', 'tabular', 'online', 'scm', 'parquet']
