# Data module: bundles the dataset collection with the loader settings.
data_module:
  _target_: 'data_module.concatenated_eeg_data_module.ConcatenatedEEGDataModule'
  name: 'emg'
  cfg:
    # Both loader settings follow the top-level `num_workers` / `batch_size` keys.
    num_workers: '${num_workers}'
    batch_size: '${batch_size}'
  # Training data is the whole top-level `datasets` mapping.
  train: '${datasets}'
  # No validation or test data is configured in this section.
  val: null
  test: null
# Task object, identified by its dotted import path.
task:
  _target_: 'tasks.wave_reconstruction.MaskedWaveletPretrainingTask'
# Learning-rate scheduler wrapper and its warmup settings.
scheduler:
  _target_: 'schedulers.cosine_lr.CosineLRSchedulerWrapper'
  # Receives the complete `trainer` section of this config.
  trainer: '${trainer}'
  warmup_epochs: 2
  # The minimum and the warmup-start learning rates are set to the same value.
  min_lr: 2.5e-07
  warmup_lr_init: 2.5e-07
  # NOTE(review): presumably makes the schedule count steps rather than epochs,
  # while `warmup_epochs` is given in epochs — confirm how the wrapper converts.
  t_in_epochs: false
# Backbone model settings.
model:
  _target_: 'models.waveecg.ChannelVisionTransformer'
  in_ch: 16
  # Wavelet-related settings.
  max_level: 3
  wave_kernel_size: 16
  wavelet_names: ['db4', 'sym4', 'bior3.5', 'coif3', 'dmey']
  use_separate_channel: true
  # Feed-forward (`ffn_*`) settings.
  ffn_ratio: 4.0
  ffn_kernel_size: 5
  ffn_drop: 0.1
  hw_square_kernel: 3
  hw_band_kernel: 15
  # `reduced_dim`, `timesteps`, `patch_size` and `embed_dim` are also
  # interpolated by the `model_head` section.
  reduced_dim: 32
  timesteps: 1024
  patch_size: [1, 64]
  # Transformer size.
  embed_dim: 384
  depth: 8
  num_heads: 12
  mlp_ratio: 4
  drop_path: 0.1
  # Also interpolated into the checkpoint filename and `io.version`.
  attention_type: 'wavelet_enhanced'
  # Masking settings.
  masking_ratio: 0.7
  importance_ratio: 0.6
  use_masking: true
  max_seq_len: 4096
# Decoder head; several values are interpolated from the `model` section.
model_head:
  _target_: 'models.model_heads.wave_decoder.MAEDecoder'
  # Follows the backbone's embedding size.
  embed_dim: '${model.embed_dim}'
  decoder_embed_dim: 256
  decoder_output_dim: 64
  decoder_num_heads: 8
  decoder_depth: 8
  mlp_ratio: 4.0
  attention_type: 'fft'
  drop_path: 0.0
  # NOTE(review): differs from `model.max_seq_len` (4096) — confirm this is intended.
  max_seq_len: 2048
  # Patch size and input height/width are taken from the `model` section.
  patch_size: '${model.patch_size}'
  in_height: '${model.reduced_dim}'
  in_width: '${model.timesteps}'
# Reconstruction loss settings.
criterion:
  _target_: criterion.reconstruction_norm.ReconstructionNorm
  # Interpolated from `model.patch_size` (the same approach `model_head.patch_size`
  # uses) so the loss cannot drift from the patching the model is configured with.
  patch_size: ${model.patch_size}
  # `alpha` and `loss_type` are also interpolated into the checkpoint filename
  # and `io.version`.
  alpha: 0.1
  loss_type: smooth_l1
  using_spectrogram: true
# Run label and random seed.
tag: 'apr_25'
seed: 42

# Hardware and loader sizing. `num_nodes`, `num_workers` and `batch_size` are
# interpolated by other sections of this config.
# NOTE(review): no `${gpus}` interpolation is visible in this config (the
# `trainer` section sets `devices: -1`) — confirm where `gpus` is consumed.
gpus: 4
num_nodes: 4
num_workers: 8
batch_size: 32

# Run-mode switches.
training: true
final_validate: true
print_model: true
find_unused_parameters: true

# Checkpoint loading / fine-tuning options; all disabled or unset here.
resume: false
pretrained_checkpoint: null
load_state_dict: false
finetune_pretrained: false
finetune_pretrained_path: null
freeze_decoder: false
# Top-level checkpoint settings.
# NOTE(review): these disagree with `callbacks.model_checkpoint`, which monitors
# `val_loss` with mode `min` and `save_last: false` — confirm which of the two
# the training code actually uses.
model_checkpoint:
  monitor: 'val_psnr'
  mode: 'max'
  save_top_k: 1
  save_last: true
# Callback objects, one entry per callback.
callbacks:
  lr_monitor:
    _target_: 'pytorch_lightning.callbacks.LearningRateMonitor'
    logging_interval: 'step'
  progress_bar:
    _target_: 'pytorch_lightning.callbacks.TQDMProgressBar'
    refresh_rate: 10
  model_checkpoint:
    _target_: 'pytorch_lightning.callbacks.ModelCheckpoint'
    dirpath: '/leonardo_work/CNHPC_1526560/yanlchen/EMG_Pretrain/checkpoints_base/'
    # Mixes config interpolations (`${...}`) with `{epoch}` / `{val_loss}`
    # placeholders that stay in the string for the checkpoint callback.
    filename: '${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}_attention${model.attention_type}-{epoch:02d}-{val_loss:.4f}'
    # NOTE(review): `data_module.val` is null in this config — confirm that
    # `val_loss` is actually logged; with `save_last: false` it appears to be
    # the only trigger for writing a checkpoint.
    monitor: 'val_loss'
    mode: 'min'
    save_top_k: 1
    save_last: false
    every_n_epochs: 1
# Base output path and run version label.
io:
  base_output_path: '/leonardo_work/CNHPC_1526560/yanlchen/EMG_Pretrain/tensorboard_logs_base'
  # Built from the same pieces as the prefix of the checkpoint filename in
  # `callbacks.model_checkpoint`.
  version: '${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}_attention${model.attention_type}'
# Trainer settings; this whole section is also handed to `scheduler.trainer`.
trainer:
  # Follows the top-level `num_nodes` key.
  num_nodes: '${num_nodes}'
  devices: -1
  strategy: 'ddp'
  # NOTE(review): both an epoch limit and a step limit are set — confirm which
  # one is meant to end the run.
  max_epochs: 10
  max_steps: 500000
  benchmark: true
  check_val_every_n_epoch: 2
  num_sanity_val_steps: 2
  accelerator: 'gpu'
  accumulate_grad_batches: 32
# Optimizer choice and hyperparameters.
optimizer:
  optim: 'AdamW'
  lr: 0.0005
  betas: [0.9, 0.98]
  weight_decay: 0.01
# Experiment label; interpolated into the checkpoint filename and `io.version`.
experiment_name: 'waveEMG_base_pretraining'
# `final_test` is turned off for this run.
final_test: false
# Dataset collection, referenced as a whole by `data_module.train`.
# Every entry uses the same dataset class with `transform: false`; the entries
# differ only in their `.h5` file path. Entry order is kept as-is.
datasets:
  emg2pose:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/leonardo_scratch/large/userexternal/ychen003/EMG_Pretrain/emg2pose2.h5'
    transform: false
  db6a:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/leonardo_scratch/large/userexternal/ychen003/EMG_Pretrain/db6a.h5'
    transform: false
  db6b:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/leonardo_scratch/large/userexternal/ychen003/EMG_Pretrain/db6b.h5'
    transform: false
  db7:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/leonardo_scratch/large/userexternal/ychen003/EMG_Pretrain/db7.h5'
    transform: false
  db8:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/leonardo_scratch/large/userexternal/ychen003/EMG_Pretrain/db8.h5'
    transform: false
