# @package _global_
# Changes from v4
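# - to 3s chop length (NOTE(review): dataset.max_length_ms and most chop_size_ms values in this file are 2000 ms, i.e. 2s - confirm which is intended)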
# - enable monitors
# - fixed epochs
# - mse as default loss

# Defaults list: each entry names a config group and the option(s) composed
# from it, in the order listed.
defaults:
  - /model:
      - flat_enc_dec
      - scale_history
      - largescale
  - /model/task:
      - decode_flat_v2  # specify individually
  - /dataset:
      - flat
      - scale_history
  - /train:
      - largescale  # for the main runs, not debugging runs
# Model overrides applied on top of the /model and /model/task defaults.
model:
  arch: Architecture.flash_ndt
  max_neuron_count: 21
  log_token_proc_throughput: true

  lr_ramp_steps: 50
  lr_decay_steps: 500
  lr_interval: epoch
  lr_schedule: cosine_timm

  next_step_prediction: true
  kinematic_token_maskout_schedule: 'random'
  kinematic_token_maskout_start: 1.0
  kinematic_token_maskout: 0.0  # 0.1

  task:
    constraint_mute: false
    return_mute: false

    tasks:
      - ModelTask.spike_infill
      - ModelTask.kinematic_linear  # MSE head
      - ModelTask.constraints
      - ModelTask.return_context
    metrics:
      - Metric.kinematic_r2
    # Set so that the rough dynamic range of the losses is <1 OOM apart.
    # Presumably one weight per entry in `tasks`, in order - TODO confirm.
    task_weights: [1.0, 1.0, 0.0, 0.0]

    context_prompt_time_thresh: 0  # None
    context_prompt_time_thresh_min: 750  # Full

    # REDACT's logic for v5: PrefixLMs generally block prefix loss, but maintaining high prefix is bad.
    # prefix_ratio: 0.5
    # block_prefix_loss: True # Our prefix - not acausal. Just a place where highly autocorrelated kin is blocked
    # v4 settings (currently active)
    prefix_ratio: 0.9
    block_prefix_loss: false

    encode_constraints: true
    use_constraint_cls: false  # No need
    decode_quantize_classes: 512

    spike_loss: cross_entropy

  transformer:
    initializer_range: 0.0
    activation: swiglu
    rotary_position: true
    rotary_position_torch: true
    use_biases: false
    learnable_norm: false
    n_layers: 6
    max_trial_length: 1500  # Up to 30s
    n_heads: 8
    pre_norm: true
    qk_normalization: true
  hidden_size: 1024
  use_full_encode: true
  decoder_context_integration: 'cross_attn'  # Literally not even used
  max_spatial_position: 48
# Dataset overrides applied on top of the /dataset defaults.
dataset:
  pack_dense: true
  bin_size_ms: 20
  max_tokens: 4096
  max_length_ms: 2000
  max_trial_length: 100

  # Per-key chop settings. Values are written out literally; the trailing
  # `${dataset.max_length_ms}` comments record the interpolation they mirror.
  odoherty_rtt:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  REDACT_co:
    try_stitch_norm: true
    limit_kin_dims: 15  # We scrape up to 14 pos dims and force, rarely expect all to be active
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  miller:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  dyer_co:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  gallego_co:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  rouse:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  churchland_misc:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  delay_reach:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  churchland_maze:
    chop_size_ms: 2000  # ${dataset.max_length_ms}
  perich:
    chop_size_ms: 2000
  hatsopoulos:
    chop_size_ms: 2000
  flint:
    chop_size_ms: 2000
  schwartz:
    chop_size_ms: 2000
  chase:
    chop_size_ms: 2000
  limblab:
    chop_size_ms: 2000
  mayo:
    chop_size_ms: 0  # don't cut
  cst:
    chop_size_ms: 0  # don't cut

  max_channels: 320
  max_arrays: 2

  datasets: []  # To be overridden

  exclude_datasets:
    # DO NOT TOUCH - TEST SET FINE-TUNING
    # For fully separate consideration
    - 'falcon.*'  # Rouse M1 and 7D H1
    - 'odoherty_rtt.*'
    - 'mayo.*'
    - 'miller_Jango.*'  # for generalization eval

  # Training time eval curves (diagnostic, not formal eval)
  eval_ratio: 0.9
  eval_datasets:
    # Small amount of common held-out data for sanity checking runs with different data sources.
    # Implication is that these datasets are partially seen in training.
    - 'flint.*'
    - 'chase_Rocky.*'

  tokenize_covariates: true
  sparse_constraints: true
  sparse_rewards: false
  behavior_dim: 17  # Force is here, include +1 for padding, etc. 16 emg signals, +1 for padding
  data_keys:
    - DataKey.spikes
    - DataKey.constraint
    - DataKey.bhvr_vel
    - DataKey.task_return
    - DataKey.bhvr_mask
# Training overrides applied on top of the /train defaults.
train:
  autoscale_batch_size: true
  batch_size: 64  # Currently takes about 24G on 40G
  effective_batch_size: 16384  # about 200K trials in a 200H.
  patience: 100
  epochs: 400

# Rarely do we launch multiple PT in error
cancel_if_run_exists: false
# PT should save last on top of best metrics.
save_last: true