# @package _global_

# Hydra defaults list: the config-group options composed into this experiment.
defaults:
  - /model: flat_enc_dec
  # Task options are supplied as a list under the /model/task group.
  - /model/task:
      - joint_bhvr_decode_flat
  - /dataset: flat
# Dataset overrides: session selection, data streams, and odoherty_rtt options.
dataset:
  # Pattern covering the odoherty_rtt Indy sessions
  # (presumably matched as a regex -- confirm against the dataset loader).
  datasets:
    - odoherty_rtt-Indy.*
  # Specific Indy sessions listed for exclusion.
  exclude_datasets:
    - odoherty_rtt-Indy-20160407_02  # First Indy session
    - odoherty_rtt-Indy-20160627_01  # Original
    - odoherty_rtt-Indy-20161005_06
    - odoherty_rtt-Indy-20161026_03
    - odoherty_rtt-Indy-20170131_02  # Last Indy session

  data_keys:
    - DataKey.spikes
    - DataKey.bhvr_vel
  odoherty_rtt:
    arrays:
      - Indy-M1
      - Loco-M1
    include_sorted: false
# Training overrides: explicit batch size with batch-size autoscaling disabled.
train:
  batch_size: 512
  autoscale_batch_size: false
# Model overrides for the joint behavior-decoding task selected in `defaults`.
model:
  causal: true
  neurons_per_token: 32
  task:
    # ! Tuned down, maybe to improve decoding (it matters now that we're joint).
    mask_ratio: 0.2
    behavior_lag: 120
  lr_ramp_steps: 50  # assumed shorter schedule
  # Don't accelerate; we're going for TAPT (long-ish).
  # accelerate_new_params: 10.0  # We're introducing a whole new readout layer...
  # lr_schedule: 'fixed'