# @package _global_

# Free-form label identifying this experiment configuration.
tag: REDACT_decode_task_scratch

# Hydra defaults list: composes the `pretrain` option of the `/model` and
# `/train` config groups into this config.
# NOTE(review): the keys further down presumably override these bases; confirm
# against the `_self_` ordering of the Hydra version in use.
defaults:
- /model: pretrain
# Alternative model preset, currently disabled:
# - /model: pretrain_2x
- /train: pretrain
# Data selection and preprocessing limits.
dataset:
  # Training sources. The trailing `*` is presumably a wildcard expanded by the
  # data loader; confirm how patterns are matched.
  datasets:
  - 'CRS02bHome.data.*'
  # A single session named explicitly for evaluation; its name also fits the
  # pattern listed under `datasets`.
  eval_datasets:
  - 'CRS02bHome.data.00437'
  # Data streams requested from each trial: spikes and behavioral velocity.
  data_keys:
  - DataKey.spikes
  - DataKey.bhvr_vel
  # The four sections below (REDACT_co, observation, ortho, fbc) each list the
  # same two CRS02b M1 arrays.
  REDACT_co:
    arrays:
    - CRS02b-lateral_m1
    - CRS02b-medial_m1
  observation:
    arrays:
    - CRS02b-lateral_m1
    - CRS02b-medial_m1
  ortho:
    arrays:
    - CRS02b-lateral_m1
    - CRS02b-medial_m1
  fbc:
    arrays:
    - CRS02b-lateral_m1
    - CRS02b-medial_m1
  # Size limits. With 20 ms bins, 4000 ms would be 200 bins per trial,
  # assuming max_length_ms is binned at bin_size_ms (TODO confirm).
  max_length_ms: 4000
  # Equals the number of arrays listed in each section above.
  max_arrays: 2
  bin_size_ms: 20
  max_channels: 96
  behavior_dim: 2
  # Metadata fields attached to each trial.
  meta_keys:
  - MetaKey.unique
  - MetaKey.session
  - MetaKey.array
  - MetaKey.subject
  - MetaKey.task
# Model overrides applied on top of the base model config.
model:
  causal: true
  # accelerate_new_params: 10.0
  # Written with an explicit decimal point so that YAML 1.1 loaders (e.g. stock
  # PyYAML) resolve it as a float; a bare `2e-4` is read as a string there.
  lr_init: 2.0e-4
  # lr_init: 5.0e-5
  # lr_schedule: fixed
  # One task (kinematic decoding) with one metric (kinematic R2).
  task:
    tasks:
    - ModelTask.kinematic_decoding
    metrics:
    - Metric.kinematic_r2
  # Embedding strategies: the task uses `EmbedStrat.token`; subject and array
  # are set to `EmbedStrat.none`.
  task_embed_strategy: EmbedStrat.token
  subject_embed_strategy: EmbedStrat.none
  # subject_embed_strategy: EmbedStrat.token
  array_embed_strategy: EmbedStrat.none
# Training-loop overrides applied on top of the base train config.
train:
  # Small batch, trying to stabilize training.
  batch_size: 32
  # Batch-size autoscaling is switched off, so the value above is used as given
  # (presumably; confirm how the trainer treats this flag).
  autoscale_batch_size: false
  epochs: 100000
  # NOTE(review): presumably an early-stopping patience; units (epochs vs.
  # steps) are not visible here, confirm against the trainer.
  patience: 5000

# Name of the sweep configuration to use; where it is resolved is not shown in
# this file.
sweep_cfg: small_wide