# @package _global_

# Config groups composed into this experiment. Entries that are commented out
# are alternative choices that are currently disabled.
# NOTE(review): the `dataset`, `model` and `train` sections further down
# presumably override values supplied by these groups — confirm.
defaults:
  - /model: flat_enc_dec
  # - /model: pretrain_small
  - /train: nlb
  - /dataset: flat
  # - /dataset: maze_nlb
# Dataset selection and preprocessing settings for this run.
dataset:
  # NOTE(review): `datasets` uses the alias `mc_maze_med` while `eval_datasets`
  # uses `mc_maze_medium` — confirm both resolve to the same dataset.
  datasets:
    - mc_maze_med
  eval_datasets:
    - mc_maze_medium
  # Presumably the fraction of the eval datasets reserved for evaluation —
  # TODO confirm against the data loader.
  eval_ratio: 0.2
  nlb_maze:
    # Same value as model.task.query_heldout (45) in this file.
    heldout_neurons: 45
  data_keys:
    - DataKey.spikes
    - DataKey.heldout_spikes
  bin_size_ms: 20
  # max_channels and neurons_per_token are both 128 here, as is
  # model.neurons_per_token.
  max_channels: 128
  neurons_per_token: 128
# Model settings for this run. The session, task and subject embedding
# strategies are all set to EmbedStrat.none; the commented lines record the
# alternative settings that are currently disabled.
model:
  # Same value as dataset.neurons_per_token and dataset.max_channels.
  neurons_per_token: 128
  session_embed_strategy: EmbedStrat.none
  session_embed_token_count: 1
  # session_embed_token_count: 8
  task_embed_strategy: EmbedStrat.none
  # task_embed_strategy: EmbedStrat.token
  subject_embed_strategy: EmbedStrat.none
  # subject_embed_strategy: EmbedStrat.token
  hidden_size: 128

  # NOTE(review): dropout is high at both levels (0.8 here, 0.6 in the
  # transformer) — presumably deliberate for this experiment; confirm.
  dropout: 0.8
  decoder_layers: 5
  transformer:
    dropout: 0.6
    # NOTE(review): this line was previously annotated "since we use 2 decoder
    # layers", but decoder_layers is 5 in this file — confirm which is intended.
    n_layers: 1
    n_heads: 2
    debug_force_nonlearned_position: true
    debug_override_dropout_io: true
  task:
    # Two weights for the two entries in `tasks` (presumably matched by
    # position — confirm).
    task_weights: [1.0, 1.0]
    tasks:
      - ModelTask.shuffle_infill
      - ModelTask.heldout_decoding
    mask_ratio: 0.25
    query_heldout: 45  # mc_maze maxed at this; equals dataset.nlb_maze.heldout_neurons
    metrics:
      - Metric.co_bps
      - Metric.block_co_bps
    decode_time_pool: 'mean'  # shouldn't matter
    decode_use_shuffle_backbone: true

  lr_ramp_steps: 3000
  lr_decay_steps: 100000
  # This is the change described by the top-level `notes` entry
  # ("Add force zero mask").
  force_zero_mask: true

# Training-loop overrides.
train:
  # Presumably the early-stopping patience; units (steps vs. epochs) are not
  # visible from this file — TODO confirm.
  patience: 5000
# Free-form description of this run. "11" presumably refers to an earlier
# experiment/config this one builds on (confirm); "Add force zero mask"
# corresponds to model.force_zero_mask being enabled above.
notes: "11 + Add force zero mask"
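# Enabling model.force_zero_mask switches the masking strategy
# (presumably what "This" referred to in the original note — confirm).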