# @package _global_

# Configuring Experiments
# https://hydra.cc/docs/patterns/configuring_experiments/
# Pretraining of the PhysioWave_ecg large model on extensive, large-scale ECG datasets.
# Run identification; `experiment_name` is interpolated into the checkpoint
# filename and the logger version string further down.
tag: apr_27
experiment_name: waveECG_large_pretraining

# `num_nodes` is referenced by `trainer.num_nodes` below.
num_nodes: 4
num_workers: 8
batch_size: 32
# Enabling this costs throughput when every parameter actually receives a
# gradient; switch it to false if the model has no unused parameters.
# NOTE(review): presumably forwarded to the DDP strategy - confirm in the
# training script.
find_unused_parameters: true


# Top-level checkpoint settings, ranked by the highest `val_psnr`.
# NOTE(review): `callbacks.model_checkpoint` in this file tracks `val_loss`
# (mode min, save_last false) instead - confirm which of the two the training
# script actually consumes and align the other.
model_checkpoint:
  save_last: true
  save_top_k: 1
  monitor: 'val_psnr'
  mode: 'max'

# No pretrained checkpoint is configured for this run.
pretrained_checkpoint: null
# NOTE(review): presumably toggles a test pass after training - confirm in the
# training script.
final_test: false
# Hydra defaults list: selects the config-group option used for each component
# of this experiment.
defaults:
  - override /data_module: concat_data_module
  - override /model: waver
  - override /model_head: wave_decoder
  - override /scheduler: cosine_lr
  - override /criterion: reconstruction_norm
  - override /task: wave_reconstruction


# Trainer callbacks. Entries that carry `_target_` name the class to
# instantiate.
callbacks:
  # Learning-rate logging, interval set to "step".
  lr_monitor:
    _target_: 'pytorch_lightning.callbacks.LearningRateMonitor'
    logging_interval: 'step'
  # tqdm progress bar with a refresh rate of 10.
  progress_bar:
    _target_: 'pytorch_lightning.callbacks.TQDMProgressBar'
    refresh_rate: 10
  # Keeps the single best checkpoint by lowest `val_loss`.
  # NOTE(review): unlike its siblings this entry has no `_target_` -
  # presumably merged into a base config or read directly by the training
  # script; confirm.
  # NOTE(review): the top-level `model_checkpoint` key monitors `val_psnr`
  # (mode max, save_last true) - confirm which one is in effect.
  # NOTE(review): `trainer.check_val_every_n_epoch` is 5, so `val_loss` is
  # only refreshed every fifth epoch even though `every_n_epochs` is 1.
  model_checkpoint:
    dirpath: '/leonardo_work/CNHPC_1526560/yanlchen/ECG_Pretrain/checkpoints_large/'
    filename: '${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}-{epoch:02d}-{val_loss:.4f}'
    monitor: 'val_loss'
    mode: 'min'
    save_top_k: 1
    save_last: false
    every_n_epochs: 1

# Pretraining corpora, one entry per HDF5 file. Every entry uses the same
# dataset class with `transform` disabled.
# NOTE(review): the class lives in an EMG-named module
# (`emg_pretrain_dataset`) although the files are ECG corpora - presumably an
# intentional reuse; confirm.
datasets:
  NEA:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/ECG_Pretrain/nea.h5'
    transform: false
  Georgia:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/ECG_Pretrain/georgia.h5'
    transform: false
  Medal:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/ECG_Pretrain/medalcare.h5'
    transform: false
  Code15:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/ECG_Pretrain/code15.h5'
    transform: false
  MIMIC:
    _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
    file_path: '/ECG_Pretrain/mimic.h5'
    transform: false

# Encoder hyper-parameters, applied on top of the `waver` model config selected
# in `defaults`.
model:
  ###############################################################################
  # 1) Wavelet-based front-end
  ###############################################################################
  in_ch: 12                 # Number of input channels (e.g. 12-lead ECG)
  max_level: 4              # Wavelet decomposition depth for multi-scale features
  wave_kernel_size: 24      # 1-D kernel length to capture time-frequency patterns
  wavelet_names: ['db6', 'sym4', 'bior3.5', 'coif3']  # Bases suited to non-stationary signals
  use_separate_channel: true  # Depthwise filtering: one wavelet filter per channel

  ###############################################################################
  # 2) FFN & CrossScale-CAFFN
  ###############################################################################
  ffn_ratio: 4.0            # Expansion ratio in the feed-forward network
  ffn_kernel_size: 5        # Depthwise-conv kernel size in CAFFN
  ffn_drop: 0.1             # Dropout rate to mitigate over-fitting

  ###############################################################################
  # 3) Patch parameters
  ###############################################################################
  patch_width: 64           # Temporal width of each patch (model_head.patch_width must match)

  ###############################################################################
  # 4) Transformer & masking
  ###############################################################################
  embed_dim: 512            # Token embedding dimension
  depth: 12                 # Number of Transformer blocks
  num_heads: 16             # Attention heads
  mlp_ratio: 4.0            # MLP expansion ratio
  drop_path: 0.15           # Stochastic-depth rate
  attention_type: "default" # Attention variant used in CustomAttentionBlock
  masking_ratio: 0.7        # MAE-style mask fraction during pre-training
  importance_ratio: 0.6     # Weight of spectral energy in mask scoring
  use_masking: true         # Enable masking for self-supervised training

  ###############################################################################
  # 5) Rotary position encoding
  ###############################################################################
  max_seq_len: 2048         # Maximum sequence length handled by RoPE

# Decoder hyper-parameters, applied on top of the `wave_decoder` head selected
# in `defaults`.
model_head:
  # Decoder settings
  decoder_embed_dim: 256    # Internal embedding size of the decoder
  decoder_num_heads: 8      # Attention heads in the decoder
  decoder_depth: 8          # Number of decoder Transformer blocks
  mlp_ratio: 4.0            # Feed-forward expansion ratio
  attention_type: 'default' # Attention variant selector for the decoder
  drop_path: 0.0            # Stochastic-depth rate in the decoder

  # Patch width must match the encoder, so it is interpolated from
  # `model.patch_width` (currently 64) instead of being duplicated by hand.
  patch_width: ${model.patch_width}

# Reconstruction loss settings. `loss_type` and `alpha` are also interpolated
# into the checkpoint filename and the logger version string.
criterion:
  save_dir: '/ECG_Pretrain/Figures'
  alpha: 0.1
  loss_type: 'smooth_l1'
  using_spectrogram: true
  
# Optimizer settings.
optimizer:
  optim: 'AdamW'
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as plain
  # PyYAML read `5e-5` as a string rather than a float.
  lr: 5.0e-5
  betas: [0.9, 0.98]
  weight_decay: 0.01


# Learning-rate schedule settings, applied on top of the `cosine_lr` scheduler
# selected in `defaults`.
scheduler:
  # Interpolates the whole `trainer` section into the scheduler config.
  trainer: ${trainer}
  # Explicit mantissa dots keep these values floats under YAML 1.1 loaders
  # such as plain PyYAML, which read `1e-6` / `2e-7` as strings.
  min_lr: 1.0e-6  # minimum LR for the cosine scheduler
  warmup_lr_init: 2.0e-7
  warmup_epochs: 5


# Output location; `version` is built from the experiment name and the loss
# settings.
io:
  base_output_path: '/ECG_Pretrain/tensorboard_logs_large'
  version: '${experiment_name}-loss${criterion.loss_type}_alpha${criterion.alpha}'


# Trainer settings.
# NOTE(review): key names match PyTorch Lightning's Trainer arguments -
# confirm how the training script consumes this section.
trainer:
  accelerator: gpu
  # Resolves to the top-level `num_nodes` (4).
  num_nodes: ${num_nodes}
  # NOTE(review): presumably "use every available device per node" - confirm.
  devices: -1
  max_epochs: 50
  strategy: ddp
  # NOTE(review): with the top-level `batch_size` of 32 this presumably yields
  # 32 x 32 samples per device per optimizer step - confirm.
  accumulate_grad_batches: 32
  # Validation metrics are therefore only produced every 5th epoch.
  check_val_every_n_epoch: 5



