# @package _global_

# H2 has data up to 47 seconds. Most configs thus far were prepared for <30s data, so a few settings need to change for long context.
# Compose this config on top of the shared ../_default config.
# NOTE(review): presumably the keys in this file override the inherited ones - confirm the `_self_` ordering for the Hydra version in use.
defaults:
  - ../_default
model:
  # Increase patch size to reduce token count, so the ultra-long context fits in memory.
  # At 192 channels, 64 neurons per token yields 3 tokens / timestep.
  neurons_per_token: 64
  # Adopted from BrainGate codebases; also heuristically the most effective setting from v4.
  dropout: 0.4
  # No kinematic inputs.
  kinematic_token_maskout_schedule: constant
  kinematic_token_maskout: 0
  # Only implemented for spike infill and seq decoding.
  assert_batch_uniform: true
  task:
    tasks:
      - ModelTask.spike_infill
      - ModelTask.seq_decoding
    # Second entry: seq decode consumes spikes as well.
    task_modality_input:
      - 0
      - 0
    task_weights:
      - 1.0
      - 0.1
    decode_time_pool: mean
    # 31 chars + 1 blank
    decode_quantize_classes: 32
    metrics:
      - Metric.cer
    # Different patch size requires re-initing spike IO.
    delete_params_on_transfer:
      - spike_infill:readin.weight
      - spike_infill:out.0.weight
      - spike_infill:out.0.bias
  transformer:
    max_trial_length: 3000  # 60 seconds
    n_layers: 6

  # Notes on LR
  # - Init
  # -- BrainGate LR is way high at 1e-2. Believe that's unstable for us. Retaining 3-5e-4.
  # - Ramp: Willett does not ramp, we keep default NDT3 ramp
  # (Hypothesis: Plausible slower ramp influences stability, no budget to explore right now)
  # - Decay: Smoketests run for longer than standard 1-1.5K epochs. Decay to 2K (for base/scratch)
  # -- At current dataset size, 1K epochs ~ 8K steps, cf. Willett 21 decays 50K steps (note our "steps" is epochs) in PT, so we're still well under full PT.
  # A compromise b/n converging to 0 and constant. This is "some" decay in the relevant timescale
  # NOTE(review): the note above says "Decay to 2K", but the active lr_decay_steps below is 3000 - confirm which is intended.
  #
  # Early evidence says ramping fast (<=100) can lead to suboptimal results, particularly for FT.
  # Let's take time; 200 seems sufficient.
  lr_ramp_steps: 200
  # Decay to 0 does seem helpful for FT.
  lr_decay_steps: 3000
  # lr_decay_steps: 10000
  # lr_decay_steps: 2000

dataset:
  # Sufficient: observed max is 2348 * 3 = 7044 tokens.
  max_tokens: 8192
  max_trial_length: 3000

  verify_preproc: true  # While we're iterating
  # Needed to preserve spatial embedding count for pretraining.
  max_channels: 640
  explicit_alias_to_session: true
  assert_no_crop: true
  count_kinematic_in_token_limit: false
  datasets:
    - falcon_FALCONH2.*
  falcon_h2:
    respect_trial_boundaries: true
    minmax: false
    # Subsampling inherited from BrainGate baselines, e.g.
    # https://github.com/cffan/CORP/blob/master/NeuralDecoder/neuralDecoder/configs/model/gru_stack_handwriting.yaml#L26
    subsample_h2: 2
  data_keys:
    - DataKey.spikes
    - DataKey.bhvr_vel
  max_length_ms: 60000  # Observed max 2348 timesteps
# Training schedule and batching for the long-context H2 runs.
train:
  # Early stopping is keyed on val_cer.
  early_stop_metric: val_cer
  epochs: 3000 # vs std 1000, these tasks often run longer.
  # patience (2000) is two thirds of epochs (3000).
  # NOTE(review): presumably this makes early stopping rarely trigger - confirm against the trainer.
  patience: 2000 # Just keep training for now... we have no idea how long this grokking takes
  batch_size: 16 # 8 on 24G at 8K tokens, 16 on 40G IG
  # NOTE(review): effective_batch_size (128) is 8x batch_size (16); presumably reached via gradient
  # accumulation and/or multiple devices - confirm against the trainer.
  max_batch_size: 128
  effective_batch_size: 128
# Sweep preset: just one seed.
sweep_cfg: high_ft_ss
# Output flags: save_cer is on and save_r2 is off.
save_r2: false
save_cer: true