# @package _global_

# Hydra defaults list: the config-group options this experiment is built from.
# A leading "/" makes the group path absolute (resolved from the config root).
# Entries are kept in their original order, since Hydra composes them in order.
defaults:
  - /model: flat_enc_dec
  - /model/task: joint_heldout_decode
  - /train: nlb
  - /dataset: flat
# Dataset settings for this experiment.
dataset:
  # Dataset identifiers used by this experiment (a single entry here).
  datasets:
    - mc_maze_small
  # Data keys requested: spikes plus held-out spikes.
  data_keys:
    - DataKey.spikes
    - DataKey.heldout_spikes
  # NOTE(review): model.neurons_per_token is set to the same value (128);
  # presumably the two must stay in sync — confirm.
  neurons_per_token: 128
  max_channels: 256
# Model settings for this experiment.
model:
  hidden_size: 128
  lr_ramp_steps: 3000
  lr_decay_steps: 100000
  # NOTE(review): dataset.neurons_per_token is set to the same value (128);
  # presumably the two must stay in sync — confirm.
  neurons_per_token: 128

  # The same dropout value is set here and inside `transformer` below.
  dropout: 0.6
  transformer:
    dropout: 0.6
    n_layers: 4
    n_heads: 2
    # Booleans are written as canonical lowercase `true`; this parses to the
    # same boolean value as `True`.
    debug_force_nonlearned_position: true
    debug_override_dropout_io: true
  task:
    mask_ratio: 0.25
    query_heldout: 45
# Training settings for this experiment.
train:
  # NOTE(review): with autoscaling off, batch_size below is presumably used
  # as-is — confirm against the training code.
  autoscale_batch_size: false
  batch_size: 64
  patience: 5000