# @package _global_

# Hydra defaults list: the config-group options this experiment is composed from.
defaults:
  - /model: flat_enc_dec
  - /dataset: flat
# Model settings for this experiment (the `flat_enc_dec` model group is
# selected in `defaults`).
model:
  hidden_size: 384
  transformer:
    n_layers: 12
    debug_force_nonlearned_position: true
  causal: false
  # NOTE(review): presumably resolved to an EmbedStrat enum member by the
  # consuming config schema - confirm.
  subject_embed_strategy: EmbedStrat.token
  task:
    mask_ratio: 0.25  # squeeze in a 2x batch
  # Same value as dataset.neurons_per_token; presumably the two must agree -
  # confirm before changing either.
  neurons_per_token: 32
# Dataset settings for this experiment (the `flat` dataset group is selected
# in `defaults`).
dataset:
  # Same value as model.neurons_per_token; presumably the two must agree -
  # confirm before changing either.
  neurons_per_token: 32

  max_arrays: 1
  max_channels: 288

  scale_ratio: 1.0
  # NOTE(review): the inline comment says "no limit" although an explicit
  # value of 300 is set; presumably 300 exceeds the available data per eval
  # session - confirm.
  scale_limit_per_eval_session: 300  # no limit

  # NOTE(review): the first entry looks like a pattern matching several
  # sessions - confirm how the loader interprets the `.*` suffix.
  datasets:
    - odoherty_rtt-Indy.*
    - mc_rtt
  eval_datasets:
    - mc_rtt

# Training settings: batch-size autoscaling is switched off and the batch
# size is pinned explicitly.
train:
  # Commented-out setting, kept for reference:
  # accumulate_batches: 4 # auto = 128 (12G GPUs), effective = 256
  autoscale_batch_size: false
  batch_size: 64

# Free-form experiment notes (model-sizing rationale).
# NOTE(review): the text ends mid-sentence ("so more like ") - complete or trim it.
notes: "If 300 trials deserves 128 (0.5M params), this 30k = 128 * 100 -> 12800, ~50M params. Except now we kill dropout, so more like "