# @package _global_

# Hydra defaults list: each entry selects one option from a config group
# (group path on the left, chosen option on the right).
# Entry order is significant to Hydra's composition, so do not reorder.
defaults:
  - /model: flat_enc_dec
  - /model/task: joint_heldout_decode
  - /train: nlb
  - /dataset: flat
# Dataset settings for this experiment. Sequences are indented under their
# key, the same style the defaults list in this file uses.
dataset:
  datasets:
    - mc_maze_large
  data_keys:
    - DataKey.spikes
    - DataKey.heldout_spikes
  # Same value as model.neurons_per_token in this file; presumably the two
  # must stay in sync -- TODO confirm against the consumer.
  neurons_per_token: 128
  max_channels: 256
# Model settings for this experiment. Booleans use canonical lowercase
# true/false, matching the train section of this file.
model:
  hidden_size: 128
  lr_ramp_steps: 3000
  lr_decay_steps: 100000
  # Same value as dataset.neurons_per_token in this file; presumably the two
  # must stay in sync -- TODO confirm against the consumer.
  neurons_per_token: 128
  # Dropout is set to 0.3 both here and under transformer (see notes).
  dropout: 0.3
  transformer:
    dropout: 0.3
    n_heads: 2
    # debug_* flags: exact effect is defined by the model code, not visible
    # in this file.
    debug_force_nonlearned_position: true
    debug_override_dropout_io: true
  task:
    mask_ratio: 0.25
    query_heldout: 45
    decode_time_pool: 'mean'
    decode_strategy: EmbedStrat.project
    decode_use_shuffle_backbone: true
# Training-loop settings for this experiment.
train:
  # Batch size is pinned to 64 with autoscaling switched off.
  batch_size: 64
  autoscale_batch_size: false
  # Presumably an early-stopping patience; its unit (steps vs. epochs) is
  # not visible in this file -- TODO confirm.
  patience: 5000
# Free-form experiment note. The 0.3 it mentions matches model.dropout and
# model.transformer.dropout in this file.
notes: 'Decrease dropout to 0.3, bc perhaps ndt2 is "harder"'