# Experiment tag; how it is consumed is not visible in this file.
tag: eeg

# Compute and data-loading sizes. `gpus` and `num_nodes` are re-used through
# interpolation by the `trainer` section of this file.
gpus: 4
num_nodes: 1
# presumably the number of DataLoader worker processes — confirm in the data module
num_workers: 4
# NOTE(review): unclear from here whether this is per device or global — confirm
batch_size: 2

# Run-control flags. None of these are referenced elsewhere in this file;
# presumably they are read by the training entry point — confirm there.
seed: 42  # Used in stable diffusion
resume: false
pretrained_checkpoint: null
load_state_dict: false
training: true
print_model: true
# NOTE(review): PyTorch Lightning releases before 2.0 defaulted DDP to
# find_unused_parameters=True; 2.0+ default to False. Confirm against the
# Lightning version this project pins before relying on "default" behaviour.
find_unused_parameters: true
final_validate: true

# Fine-tuning switches, all disabled/unset here. How they interact with
# `pretrained_checkpoint` above is not visible in this file — confirm in the
# code that reads them.
finetune_pretrained: false
finetune_pretrained_path: null
freeze_decoder: false
# Checkpoint-selection settings. The key names mirror arguments of Lightning's
# ModelCheckpoint callback (keep the last checkpoint plus the top-1 by the
# monitored metric, higher is better).
# NOTE(review): presumably merged into `callbacks.model_checkpoint` by the
# training code — confirm, since that callback entry sets no `monitor` itself.
model_checkpoint:
  save_last: true
  save_top_k: 1
  monitor: 'val_psnr'
  mode: 'max'

# Callback definitions. Each entry carries a `_target_` class path plus the
# keyword arguments for that class (Hydra instantiate convention).
callbacks:
  # Records the learning rate at every optimizer step.
  lr_monitor:
    _target_: 'pytorch_lightning.callbacks.LearningRateMonitor'
    logging_interval: step
  # Console progress bar; redrawn every 10 batches.
  progress_bar:
    _target_: 'pytorch_lightning.callbacks.TQDMProgressBar'
    refresh_rate: 10
  # Writes checkpoints to a hard-coded cluster directory.
  # NOTE(review): the filename template embeds `val_loss`, while the top-level
  # `model_checkpoint.monitor` is `val_psnr` — confirm that `val_loss` is
  # actually logged, otherwise the filename will not carry a meaningful value.
  model_checkpoint:
    _target_: 'pytorch_lightning.callbacks.ModelCheckpoint'
    dirpath: '/cluster/work/cvl/eeg_foundation/Pretraining_Experiments/checkpoints/'
    filename: 'chkp-{epoch:02d}-{val_loss:.2f}'
  
# datasets:
#   demo_dataset:
#     _target_: 'datasets.custom_eeg_dataset.CustomEEGDataset'
#     data_path: "/cluster/work/cvl/eeg_foundation/demo_data/all_data.npy"
#     labels_path: "/cluster/work/cvl/eeg_foundation/demo_data/all_labels.npy"


# Output location for run artefacts.
# NOTE(review): this path lives on a different filesystem root than the
# checkpoint `dirpath` under `callbacks` — confirm that is intentional.
io:
  base_output_path: '/srv/beegfs02/scratch/eeg_signal_analysis/data/eeg_demo_outputs/'
  # presumably a run/version index used under base_output_path — confirm
  version: 0

# Arguments for the PyTorch Lightning Trainer. Node and device counts are
# interpolated from the top-level `num_nodes` / `gpus` keys so they are
# defined in exactly one place.
trainer:
  num_nodes: ${num_nodes}
  devices: ${gpus}
  strategy: auto
  # Both limits are set; Lightning stops at whichever is reached first.
  max_epochs: 100
  max_steps: 500000
  # Sets torch.backends.cudnn.benchmark.
  benchmark: true
  check_val_every_n_epoch: 1
  # Validation batches run as a sanity check before training starts.
  num_sanity_val_steps: 2

# Optimizer selection. `optim` is a name string; how it is mapped to an
# optimizer class is not visible in this file.
optimizer:
  optim: 'AdamW'
  # Written with an explicit mantissa decimal point: YAML 1.1 loaders (e.g.
  # plain PyYAML) resolve a bare `1e-4` as the *string* "1e-4", not a float.
  lr: 1.0e-4

# Hydra Defaults List: one option is selected from each config group.
# NOTE(review): there is no `_self_` entry; Hydra >= 1.1 warns when a primary
# config mixes its own values with a Defaults List without `_self_`. Confirm
# the intended precedence before adding it.
defaults:
  - data_module: concat_data_module
  - task: base_task
  - scheduler: multi_step_lr
  # No model is selected by default; presumably chosen via an override.
  - model: null
  - model_head: mlp_classification_head
  - criterion: reconstruction_norm
