# Data module: `_target_` names the class to instantiate (Hydra convention).
# Loader batch size and worker count come from the top-level `batch_size` and
# `num_workers` keys.
data_module:
  _target_: data_module.concatenated_eeg_data_module.ConcatenatedEEGDataModule
  name: emg
  cfg:
    num_workers: ${num_workers}
    batch_size: ${batch_size}
  # Training datasets are the entries of the top-level `datasets` mapping.
  train: ${datasets}
  # NOTE(review): no validation/test datasets are configured, yet the trainer
  # validates every 2 epochs, `final_validate` is true, and checkpointing
  # monitors val_* metrics. Confirm ConcatenatedEEGDataModule derives a
  # validation split itself; otherwise those monitors never receive a value.
  val: null
  test: null
# Pretraining task (masked wavelet reconstruction, per the class path).
task:
  _target_: tasks.wave_reconstruction.MaskedWaveletPretrainingTask
# Cosine learning-rate schedule with warmup. The whole `trainer` config is
# passed in via interpolation.
scheduler:
  _target_: schedulers.cosine_lr.CosineLRSchedulerWrapper
  trainer: ${trainer}
  # NOTE(review): warmup is given in epochs, but `t_in_epochs: false` suggests
  # the schedule advances per step (timm convention). Confirm the wrapper
  # converts warmup_epochs to steps.
  warmup_epochs: 10
  min_lr: 2.5e-07
  warmup_lr_init: 2.5e-07
  t_in_epochs: false
# Encoder backbone. Several keys here are interpolated elsewhere:
# embed_dim, patch_kernel, reduced_dim and timesteps by model_head, and
# attention_type by the io.version and checkpoint filename strings.
model:
  _target_: models.waveMiM.ChannelVisionTransformer
  in_ch: 16
  max_level: 4
  wave_kernel_size: 24
  # NOTE(review): 1000 is not a multiple of the 64-sample patch stride
  # (1000 / 64 = 15.625). Confirm how the model pads or drops the remainder.
  timesteps: 1000
  # Kernel equals stride. Assuming conv-style patch embedding (confirm in
  # ChannelVisionTransformer), patches tile the time axis without overlap.
  patch_kernel: [1, 64]
  patch_stride: [1, 64]
  embed_dim: 512
  depth: 8
  num_heads: 8
  mlp_ratio: 4.0
  drop_path: 0.1
  attention_type: wavelet_enhanced
  masking_ratio: 0.7
  hw_square_kernel: 3
  hw_band_kernel: 15
  reduced_dim: 24
  ffn_kernel_size: 5
  importance_ratio: 0.7
# MAE-style decoder head. Its geometry is interpolated from the encoder config
# so the two stay consistent.
model_head:
  _target_: models.modules.wave_decoder.MAEDecoder
  embed_dim: ${model.embed_dim}
  decoder_embed_dim: 512
  # NOTE(review): 64 equals the number of values in one 1x64 patch
  # (model.patch_kernel). Presumably one output per patch element; confirm.
  decoder_output_dim: 64
  decoder_num_heads: 8
  decoder_depth: 8
  mlp_ratio: 4.0
  # Dotted class path given as a string; presumably resolved by MAEDecoder.
  norm_layer: torch.nn.LayerNorm
  attention_type: fft
  drop_path: 0.0
  max_seq_len: 2048
  patch_size: ${model.patch_kernel}
  # NOTE(review): height is taken from model.reduced_dim (24), not model.in_ch
  # (16). Confirm the encoder output really has reduced_dim rows.
  in_height: ${model.reduced_dim}
  in_width: ${model.timesteps}
# Reconstruction loss.
# alpha/beta/gamma are presumably loss-term weights; see ReconstructionNorm.
# loss_type, alpha, beta and gamma are also embedded in io.version and the
# checkpoint filename.
criterion:
  _target_: criterion.reconstruction_norm.ReconstructionNorm
  # Interpolated from the encoder's patch kernel (as model_head.patch_size is),
  # so the loss always patches the signal exactly as the model does. A
  # hard-coded copy would silently drift if the model's patching changed.
  patch_size: ${model.patch_kernel}
  alpha: 0.7
  loss_type: smooth_l1
  using_spectrogram: true
  beta: 0.2
  gamma: 0.1
  save_dir: /capstor/scratch/cscs/cyanlong/ECG_Pretrain/waveECG_pretraining/figures
# Run-level settings. num_nodes, num_workers and batch_size are interpolated
# into other sections; the rest are presumably read by the entry script.
tag: apr_13
# NOTE(review): `gpus` is not interpolated anywhere in this file
# (trainer.devices is -1). Confirm a launcher consumes it, or wire it in.
gpus: 4
num_nodes: 2
num_workers: 16
batch_size: 64
seed: 42
resume: false
pretrained_checkpoint: null
load_state_dict: false
training: true
print_model: true
# Presumably forwarded to the DDP strategy (trainer.strategy is ddp); confirm.
find_unused_parameters: true
final_validate: true
finetune_pretrained: false
finetune_pretrained_path: null
freeze_decoder: false
# NOTE(review): this top-level block (monitor val_psnr, mode max, save_last
# true) disagrees with callbacks.model_checkpoint (monitor val_loss, mode min,
# save_last false). Confirm which one the training script uses, and align or
# remove the other.
model_checkpoint:
  save_last: true
  save_top_k: 1
  monitor: val_psnr
  mode: max
# Lightning callbacks, one `_target_` entry each.
callbacks:
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
  progress_bar:
    _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 10
  # Keeps the best checkpoint by validation loss.
  # The filename encodes the same loss hyper-parameters as io.version,
  # including beta and gamma. Runs in a sweep share `dirpath`, so this keeps
  # their checkpoints distinguishable.
  # `{epoch:02d}` / `{val_loss:.4f}` are Lightning filename placeholders, not
  # OmegaConf interpolations.
  # NOTE(review): data_module.val is null. Confirm val_loss is actually logged;
  # otherwise this callback has nothing to rank on.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: /capstor/scratch/cscs/cyanlong/ECG_Pretrain/waveECG_pretraining/checkpoints/
    filename: ${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}_beta${criterion.beta}_gamma${criterion.gamma}_attention${model.attention_type}-{epoch:02d}-{val_loss:.4f}
    monitor: val_loss
    mode: min
    save_top_k: 1
    save_last: false
    every_n_epochs: 1
# Logging output: base directory for logs (TensorBoard, per the path name) and
# the run version name. The version encodes the loss hyper-parameters and the
# encoder attention type.
io:
  base_output_path: /capstor/scratch/cscs/cyanlong/ECG_Pretrain/waveECG_pretraining/tensorboard_logs
  version: ${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}_beta${criterion.beta}_gamma${criterion.gamma}_attention${model.attention_type}
# Trainer arguments (presumably passed to pytorch_lightning.Trainer). The whole
# mapping is also forwarded to the LR scheduler via scheduler.trainer.
trainer:
  num_nodes: ${num_nodes}
  # -1 selects every visible GPU on each node (Lightning convention).
  # NOTE(review): the top-level `gpus: 4` is not used here; confirm intent.
  devices: -1
  strategy: ddp
  # In Lightning, training stops at whichever of these limits is reached first.
  max_epochs: 50
  max_steps: 500000
  benchmark: true
  check_val_every_n_epoch: 2
  num_sanity_val_steps: 2
  accelerator: gpu
  # NOTE(review): this assumes batch_size is per process and 4 GPUs per node.
  # Under those assumptions, one optimizer step sees
  # 64 x 4 x 2 x 64 = 32768 samples. Confirm optimizer.lr is tuned for that.
  accumulate_grad_batches: 64
# Optimizer selection and hyper-parameters.
# NOTE(review): whether `lr` is absolute or scaled by batch size depends on the
# consuming code. Confirm, given the large gradient accumulation.
optimizer:
  optim: AdamW
  lr: 0.00125
  betas: [0.9, 0.98]
  weight_decay: 0.05
# Prefix for the io.version run name and the checkpoint filename.
# NOTE(review): the name says "ECG", but this config trains on EMG data
# (data_module.name: emg, EMGPretrainDataset). Renaming changes output paths,
# so confirm before changing.
experiment_name: waveECG_pretraining
final_test: false
# Training datasets, passed wholesale as data_module.train. Each entry is
# presumably instantiated from its `_target_`.
datasets:
  nea:
    _target_: datasets.emg_pretrain_dataset.EMGPretrainDataset
    data_path: /capstor/scratch/cscs/cyanlong/nea.h5
    transform: true
