# @package _global_

# We pick a small subset of sessions with R2 > 0.5 from Nic/Will's experiments to train a probe on.
# Due to the small size, we drop cross-attention (cross attn).
# We focus on CRS02b.

# Hydra defaults list: config-group options composed into this experiment.
defaults:
  - /model: flat_enc_dec
  # A list value selects multiple options from one group; only one is used here.
  - /model/task: [bhvr_decode_flat]
  - /dataset: flat
dataset:
  # NOTE(review): presumably the fraction of data held out for evaluation — confirm against the dataset loader.
  eval_ratio: 0.2
  data_keys:
  - DataKey.spikes
  - DataKey.bhvr_vel

  # Explicit list of CRS02b sessions to train on (patterns end in `.*`).
  # Commented-out entries are sessions deliberately left out of this subset.
  datasets:
  - observation_CRS02bLab_session_1809.*
  - observation_CRS02bLab_session_1811.*
  # - observation_CRS02bLab_session_1816.*
  - observation_CRS02bLab_session_1820.*
  - observation_CRS02bLab_session_1821.*
  - observation_CRS02bLab_session_1827.*
  - observation_CRS02bLab_session_1835.*
  - observation_CRS02bLab_session_1836.*
  # - observation_CRS02bLab_session_1845.*
  - observation_CRS02bLab_session_1848.*
  - observation_CRS02bLab_session_1849.*
  # Session 1865 is also listed under exclude_datasets ("Bad R2s"), so it is
  # not included here; the two lists previously contradicted each other.
  # - observation_CRS02bLab_session_1865.*
  - observation_CRS02bLab_session_1866.*
  - observation_CRS02bLab_session_1868.*
  - observation_CRS02bLab_session_1871.*
  - observation_CRS02bLab_session_1880.*
  - observation_CRS02bLab_session_1883.*
  # - observation_CRS02bLab_session_1885.*
  - observation_CRS02bLab_session_1889.*

  # - observation_CRS02bLab_session_18.*
  # - observation_CRS07Lab_session_15.*
  # Patterns that must never be loaded, regardless of the include list above.
  # NOTE(review): assumes exclusion takes precedence over inclusion — confirm in the loader.
  exclude_datasets:
  - observation_CRS02bLab_session_19.*  # test set
  # Bad R2s
  - observation_CRS02bLab_session_1877.*
  - observation_CRS02bLab_session_1865.*
  - observation_CRS02bLab_session_1858.*
  - observation_CRS02bLab_session_1851.*

  # - observation_CRS07Lab_session_16.*
# Model overrides applied on top of the composed defaults.
model:
  causal: true
  neurons_per_token: 32
  lr_ramp_steps: 50
  # Cross-attention stays disabled for this probe (see the header note).
  # decoder_context_integration: 'cross_attn'
  task:
    decode_time_pool: ''
    task_weights:
      - 1.0
      - 0.1
    # Updated from 120, swept in an unreported experiment.
    # 40-100 seems reasonably robust.
    behavior_lag: 40
    decode_normalizer: REDACT_obs_zscore.pt
# Training-loop overrides.
train:
  # NOTE(review): presumably early-stopping patience — confirm the unit (epochs vs. evaluations) against the trainer.
  patience: 50
# NOTE(review): presumably names a prior experiment whose settings/weights this run inherits — confirm how the launcher resolves it.
inherit_exp: REDACT_v2