# @package _global_

# Free-form label for this experiment; presumably used to name or group the
# resulting run (verify against the launcher).
tag: REDACT_obs_factor_decode

# Hydra defaults list: selects the `pretrain` option from the `/model` and
# `/train` config groups as the base for this experiment.
# NOTE(review): there is no explicit `_self_` entry, so whether the keys in
# this file are merged before or after these defaults depends on the Hydra
# version's implicit ordering. Confirm the overrides below take effect.
defaults:
  - /model: pretrain
  - /train: pretrain
# Dataset selection and preprocessing limits for this experiment.
dataset:
  # Dataset identifiers included in this run.
  datasets:
    - CRS02bHome.data.00329
    - CRS02bHome.data.00336
    - CRS02bHome.data.00339
    - CRS02bHome.data.00345
    - CRS02bHome.data.00360
    - CRS02bHome.data.00371
    - CRS02bHome.data.00402
    - CRS02bHome.data.00422
    - CRS02bHome.data.00424
    - CRS02bHome.data.00437
    # - CRS02bHome.data.00445
  # 00437 is also listed under `datasets` above; how the evaluation portion is
  # split off is decided by the loader, which is not visible here.
  eval_datasets:
    - CRS02bHome.data.00437
  data_keys:
    - DataKey.spikes
    - DataKey.bhvr_vel
  # Each of the four entries below lists the same two arrays.
  REDACT_co:
    arrays:
      - CRS02b-lateral_m1
      - CRS02b-medial_m1
  observation:
    arrays:
      - CRS02b-lateral_m1
      - CRS02b-medial_m1
  ortho:
    arrays:
      - CRS02b-lateral_m1
      - CRS02b-medial_m1
  fbc:
    arrays:
      - CRS02b-lateral_m1
      - CRS02b-medial_m1
  # 4000 ms at 20 ms per bin would be 200 bins, assuming the loader bins by
  # `bin_size_ms` (TODO confirm).
  max_length_ms: 4000
  max_arrays: 2
  bin_size_ms: 20
  max_channels: 96
  behavior_dim: 2
  meta_keys:
    - MetaKey.unique
    - MetaKey.session
    - MetaKey.array
    - MetaKey.subject
# Model settings for this experiment.
model:
  causal: true
  # Written with a decimal point so that YAML 1.1 loaders also resolve it as a
  # float; a bare `3e-4` is read as a string by such loaders.
  lr_init: 3.0e-4
  task:
    tasks:
      - ModelTask.kinematic_decoding
    metrics:
      - Metric.kinematic_r2
  subject_embed_strategy: EmbedStrat.token
  array_embed_strategy: EmbedStrat.none

  spike_embed_style: EmbedStrat.token
  # The original note here read "we've got 192 tokens". With
  # dataset.max_arrays (2) x dataset.max_channels (96) there are 192 channels,
  # i.e. 24 groups of 8, so the note presumably meant channels. Verify against
  # the tokenizer.
  neurons_per_token: 8
  hidden_size: 128
  transform_space: true
  transformer:
    factorized_space_time: true
# Training settings for this experiment.
train:
  epochs: 100000
  # Presumably an early-stopping patience; its unit (epochs vs. validation
  # checks) is not visible here, so verify against the trainer.
  patience: 5000
# Names a sweep configuration; where `small_wide` is defined is not visible
# in this file.
sweep_cfg: small_wide