# @package _global_
# Short label for this configuration.
# NOTE(review): presumably used to name or identify the run - confirm against the consumer.
tag: rtt_ft

# Model settings: held-out decoding task with a trainable (unfrozen) backbone,
# token embedding strategies for read-in and subject, and a cosine-warmup
# learning-rate schedule.
model:
  task:
    task: ModelTask.heldout_decoding
    freeze_backbone: false
    metrics:
    - Metric.co_bps
    - Metric.block_co_bps
    outputs:
    - Output.heldout_logrates
  readin_strategy: EmbedStrat.token
  subject_embed_strategy: EmbedStrat.token
  lr_schedule: 'cosine_warmup'
  # Written with an explicit decimal point in the mantissa: a dot-less exponent
  # form such as 2e-5 is read as a string (not a float) by YAML 1.1 loaders
  # like PyYAML, while 2.0e-5 is a float under both YAML 1.1 and 1.2.
  lr_init: 2.0e-5 # just a test for now
  lr_ramp_steps: 500 # because tiny batches
  lr_decay_steps: 5000
  # Previous value, kept for reference:
  # weight_decay: 0.01
  weight_decay: 0.03
  dropout: 0.4
# Data settings: a single dataset (mc_rtt), with spikes and held-out spikes
# requested as data keys.
# NOTE(review): heldout_spikes presumably feeds the heldout_decoding task set
# in the model section - confirm.
dataset:
  bin_size_ms: 5
  datasets:
  - mc_rtt
  max_channels: 98
  max_arrays: 1
  # Metadata keys in use; the array key is currently disabled (commented out).
  meta_keys:
  - MetaKey.unique
  - MetaKey.session
  # - MetaKey.array
  - MetaKey.subject
  data_keys:
  - DataKey.spikes
  - DataKey.heldout_spikes
# Training-loop settings for this run.
# NOTE(review): epochs is very large; presumably patience-based early stopping
# is what ends training in practice - confirm against the trainer.
train:
  epochs: 50000
  batch_size: 32 # smaller batches during tuning
  patience: 500 # smaller batches

# NOTE(review): presumably the ID of an earlier run whose checkpoint this run
# initializes from - confirm against the loading code.
init_from_id: "rtt_cos-v2ovd8aq"