# @package _global_

# Mirrors `scale_decode`: pretrain a stable decoder while varying the unsupervised ("unsup") source.
# The tag refers to the unsup source.

# Hydra defaults list: config-group selections composed into this experiment.
# Entries are kept in their original order.
defaults:
  - /model: flat_enc_dec
  # Given as a list (currently a single entry).
  - /model/task:
      - joint_bhvr_decode_flat
  - /dataset: flat
  - /train: single_session_exp1
# Dataset settings for this experiment.
dataset:
  # Data keys requested for this experiment.
  data_keys:
    - DataKey.spikes
    - DataKey.bhvr_vel
  # `datasets:` is intentionally not set here - it will be specified in the calling script.
# Model settings for this experiment (the base model group is chosen in `defaults`).
model:
  causal: true
  neurons_per_token: 32
  decoder_context_integration: 'cross_attn'
  task:
    decode_time_pool: ""
    # NOTE(review): two weights are given although `defaults` lists a single task config;
    # presumably the joint task expands to two losses - confirm.
    task_weights: [1.0, 0.1]
    mask_ratio: 0.5
    behavior_lag: 0  # No lag for human data. For parity with ongoing exps.
    decode_normalizer: REDACT_obs_zscore.pt

  accelerate_new_params: 10.0  # We're introducing a whole new readout layer...
  lr_schedule: 'fixed'
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. plain PyYAML)
  # resolve a bare `4e-5` as a string rather than a float.
  lr_init: 4.0e-5
  val_iters: 10

  # NOTE(review): `{shared_dir}` is not OmegaConf `${...}` interpolation; presumably it is
  # substituted by the consuming code - confirm.
  extra_task_embed_ckpt: '{shared_dir}/pretrained/pretrained_unsup.ckpt'
# Training settings for this experiment.
train:
  autoscale_batch_size: false
  # Assuming we have ~50-100 trials.
  batch_size: 8
  patience: 75
# Name of the experiment set this config belongs to.
experiment_set: 'online_bci'
# NOTE(review): presumably the experiment whose trained weights this run starts from -
# confirm against the code that reads `inherit_exp`.
inherit_exp: 'REDACT_v3/decode_full'