# CrissCross Word Classification Evaluation Configuration

# Model checkpoints
model:
  # Selects the training regime:
  #   false -> fine-tune starting from the pretrained weights
  #   true  -> train from scratch
  train_from_scratch: false

  # Pretrained CrissCrossTransformer checkpoint.
  # With train_from_scratch=false its weights are loaded; with
  # train_from_scratch=true only the architecture is taken from it.
  # Alternative checkpoint (disabled):
  # criss_cross_checkpoint: './checkpoints/criss-cross-multi-dataset-pretrain/dphy4ltd/last.ckpt'
  criss_cross_checkpoint: './checkpoints/criss-cross-multi-dataset-pretrain-50Hz/c5z56xy8/last.ckpt'  # CamCAN

  # BioCodec tokenizer checkpoint
  tokenizer_checkpoint: './brainstorm/neuro_tokenizers/biocodec_ckpt.pt'

  # Configuration of the word-embedding MLP
  word_mlp:
    hidden_dim: 2048
    embed_dim: 1024  # Matches the T5-large embedding dimension
    dropout: 0.1

# Data settings
data:
  # Which dataset to use: 'armeni' (default), 'gwilliams', or 'libribrain'
  dataset_type: 'armeni'

  # Dataset root directory and cache directory
  root: '/path/to/armeni2022'
  cache_dir: './data/cache'

  # Subjects to include (all 3 subjects)
  subjects: ['sub-001', 'sub-002', 'sub-003']

  # Split mode selector:
  #   true  -> sentence-based hashed splitting
  #   false -> traditional session-based temporal split
  use_hashed_split: true

  # --- Hashed split (only used when use_hashed_split=true) ---
  split_ratios:
    - 0.8  # train
    - 0.1  # val
    - 0.1  # test
  # Sessions to include; the dataset skips missing subject/session combos
  all_sessions:
    - 'ses-001'
    - 'ses-002'
    - 'ses-003'
    - 'ses-004'
    - 'ses-005'
    - 'ses-006'
    - 'ses-007'
    - 'ses-008'
    - 'ses-009'
    - 'ses-010'

  # --- Session-based temporal split (only used when use_hashed_split=false) ---
  # Train on early sessions, validate on a middle one, test on a late one.
  # NOTE(review): train_pct is not explained in this file; presumably the
  # fraction of the training data that is used — confirm against the loader.
  train_pct: 0.25
  train_sessions:
    - 'ses-001'
    - 'ses-002'
    - 'ses-003'
    - 'ses-004'
    - 'ses-005'
    - 'ses-006'
    - 'ses-007'
    - 'ses-008'

  val_sessions:
    # - 'ses-001'
    - 'ses-009'

  test_sessions:
    # - 'ses-001'
    - 'ses-010'

  tasks: ['compr']

  # Preprocessing parameters; these must match CrissCross pretraining
  l_freq: 0.1
  h_freq: 40.0
  target_sfreq: 50.0

  # Word window parameters
  # NOTE(review): segment_length / subsegment_duration = 150 / 3 = 50, which
  # equals words_per_segment — presumably intentional; confirm before changing
  # any one of the three on its own.
  segment_length: 150.0
  subsegment_duration: 3.0
  words_per_segment: 50
  window_onset_offset: -0.5

# T5 embedding settings
t5:
  model_name: 't5-large'
  # Layer from which the embeddings are extracted
  layer: 12
  cache_dir: './embeddings_cache'

# Training settings
training:
  # Samples per batch.
  # NOTE(review): with data.words_per_segment=50 this presumably yields
  # 1 sample * 50 words = 50 words per batch — confirm against the dataloader.
  batch_size: 1
  num_epochs: 50

  # Resume from checkpoint (path to checkpoint_latest.pt or checkpoint_best.pt);
  # null starts a fresh run.
  resume_checkpoint: null

  # Learning rates for the two training modes.
  # Fine-tuning mode (model.train_from_scratch=false): differential learning rates
  criss_cross_lr: 1.0e-5  # Small LR to preserve pretrained features
  word_mlp_lr: 1.0e-3     # Larger LR because the word MLP is learned from scratch

  # From-scratch mode (model.train_from_scratch=true): one LR for all components
  from_scratch_lr: 1.0e-4  # Higher than criss_cross_lr; matches pretraining script

  # Regularization
  weight_decay: 1.0e-4
  gradient_clip_val: 1.0

  # Early stopping
  patience: 10
  min_delta: 0.001

  # DataLoader settings
  num_workers: 6
  pin_memory: true

# Loss settings (SigLIP)
loss:
  # 'xy' normalizes both predictions and targets
  norm_kind: 'xy'
  # NOTE(review): presumably these enable the SigLIP temperature and bias
  # terms — confirm against the loss implementation.
  temperature: true
  bias: true
  reduction: 'sum'

# Evaluation settings
evaluation:
  # Retrieval set sizes. Accuracy is computed for samples whose labels fall
  # within these top-K words; with 50 and 250, accuracy is computed when
  # retrieving from the top-50 and from the top-250 words.
  retrieval_set_sizes:
    - 50
    - 250
  # K used for top-k accuracy (k=10 gives top-10 accuracy)
  k: 10

# Logging settings
logging:
  # Project and experiment identifiers
  wandb_project: 'brainstorm-eval'
  experiment_name: 'criss-cross-word-classification-1ses-50Hz'
  # Logging interval, in steps
  log_every_n_steps: 10
  save_dir: './logs/word_classification'

# System settings
# Compute device
device: 'cuda'
# Random seed
seed: 42
