# @package _global_
# Hydra defaults list: the config groups this experiment is composed from.
# The keys below (model, dataset, train) override values from these groups.
# NOTE(review): composition is presumably order-sensitive -- keep entry order.
defaults:
  - /model: flat_enc_dec
  - /model/task: joint_heldout_decode
  - /dataset: flat
  - /train: pretrain
# Model overrides for this experiment.
model:
  session_embed_token_count: 1
  causal: false
  hidden_size: 128
  neurons_per_token: 32  # same value as dataset.neurons_per_token
  task:
    mask_ratio: 0.25
    query_heldout: 45

  # Learning-rate schedule.
  # lr_init is written with an explicit mantissa dot: YAML 1.1 loaders
  # (e.g. plain PyYAML) resolve a bare `5e-5` as a string, not a float.
  lr_init: 5.0e-5
  lr_ramp_steps: 1000
  lr_decay_steps: 10000
  accelerate_new_params: 10.0
  tune_decay: 0.75  # value per Kaiming He et al., MAE (Masked Autoencoders)
# Dataset overrides: mc_rtt is listed for both training and evaluation.
dataset:
  data_keys:
    - DataKey.spikes
    - DataKey.heldout_spikes
  datasets:
    - mc_rtt
  eval_datasets:
    - mc_rtt
  eval_ratio: 0.2
  max_channels: 288  # 288 = 9 x 32 (neurons_per_token)
  neurons_per_token: 32  # same value as model.neurons_per_token
# Training overrides: batch-size autoscaling off, fixed batch size of 64.
train:
  autoscale_batch_size: false
  batch_size: 64
# Initialization source.
# NOTE(review): presumably the ID of a prior run and the checkpoint tag to
# load from it -- confirm against the checkpoint-loading code.
init_from_id: 'rtt_f32_parity-kl5m5jnj'
init_tag: 'val_loss'