# @package _global_

# Configuring Experiments
# https://hydra.cc/docs/patterns/configuring_experiments/
# Fine-tuning the PhysioWave_ecg pretrained small model on the PTB‑XL dataset for a five‑class classification task.
# Run identification; experiment_name is reused further down through
# ${experiment_name} interpolation.
tag: apr_28
experiment_name: waveECG_small_fintune_ptbxl

# Compute and data-loading knobs, referenced below through ${...} interpolation.
num_nodes: 1
num_workers: 8
batch_size: 32
# Not referenced anywhere else in this file; presumably forwarded to the DDP
# strategy by the training code (confirm). With DDP, enabling this costs an
# extra autograd-graph traversal per step, so switch it to false if every
# parameter is guaranteed to receive a gradient.
find_unused_parameters: true

# Checkpoint to start fine-tuning from, and a flag that presumably triggers an
# evaluation on the test split once training ends (confirm in the training code).
pretrained_checkpoint: '/ECG_checkpoints/ECG_small.ckpt'
final_test: true
# Hydra defaults list: every entry replaces the option already chosen for that
# config group; the keys defined in the rest of this file then refine it.
defaults:
  - override /model: waver
  - override /model_head: mlp_classification_head
  - override /scheduler: cosine_lr
  - override /criterion: ce_criterion
  - override /task: classification_task


# Training callbacks. model_checkpoint carries no _target_ in this file, so its
# class is presumably supplied by the base config this experiment is merged into.
callbacks:
  # Records the learning rate at every step.
  lr_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step
  # Console progress bar, refreshed every 10 batches.
  progress_bar:
    _target_: pytorch_lightning.callbacks.TQDMProgressBar
    refresh_rate: 10
  # Checkpoints are written once per epoch into a per-experiment directory; the
  # 5 best according to val_loss are kept, plus the most recent one.
  model_checkpoint:
    dirpath: '/ECG_Finetune/checkpoints/${experiment_name}'
    filename: '${experiment_name}-{epoch:02d}-{val_loss:.4f}'
    every_n_epochs: 1
    save_top_k: 5
    monitor: val_loss
    save_last: true

# Explicitly sets the `datasets` node to null; presumably this clears a value
# inherited from the base config, since the splits are declared under
# `data_module` in this file (confirm).
datasets: null

# Data module wiring: one dataset per split (train / val / test), each reading
# its own .h5 file, all sharing the loader settings defined at the top level.
# NOTE(review): the targets live in modules named after EEG/EMG although this
# experiment uses ECG (PTB-XL) data - presumably generic loaders being reused;
# confirm.
data_module:
  _target_: data_module.eeg_data_module.EEGDataModule
  cfg:
    num_workers: ${num_workers}
    batch_size: ${batch_size}
  train:
    _target_: datasets.emg_finetune_dataset.FinetuneDataset
    file_path: '/ECG_Finetune/ptbxl/train.h5'
    transform: false
  val:
    _target_: datasets.emg_finetune_dataset.FinetuneDataset
    file_path: '/ECG_Finetune/ptbxl/val.h5'
    transform: false
  test:
    _target_: datasets.emg_finetune_dataset.FinetuneDataset
    file_path: '/ECG_Finetune/ptbxl/test.h5'
    transform: false

# Backbone hyper-parameters layered on top of the `waver` model config.
# NOTE(review): these presumably have to match the architecture stored in the
# pretrained checkpoint - confirm before changing any of them.
model:
  # --- 1) Wavelet-based front-end -------------------------------------------
  in_ch: 12  # input channels (12-lead ECG)
  max_level: 4  # wavelet decomposition depth (multi-scale features)
  wave_kernel_size: 24  # length of the 1-D wavelet kernel
  wavelet_names:  # wavelet bases made available to the front-end
    - 'db6'
    - 'sym4'
    - 'bior3.5'
    - 'coif3'
  use_separate_channel: true  # one wavelet filter per channel (depthwise)

  # --- 2) FFN & CrossScale-CAFFN --------------------------------------------
  ffn_ratio: 4.0  # feed-forward expansion ratio
  ffn_kernel_size: 5  # depthwise-conv kernel size inside CAFFN
  ffn_drop: 0.1  # dropout rate in the FFN

  # --- 3) Patch parameters --------------------------------------------------
  patch_width: 64  # temporal width of each patch

  # --- 4) Transformer & masking ---------------------------------------------
  embed_dim: 256  # token embedding dimension
  depth: 6  # number of Transformer blocks
  num_heads: 8  # attention heads per block
  mlp_ratio: 4.0  # MLP expansion ratio
  drop_path: 0.15  # stochastic-depth rate
  attention_type: 'default'  # attention variant used in CustomAttentionBlock
  # NOTE(review): the three masking settings below are described as
  # pre-training / self-supervised options, yet masking stays enabled in this
  # fine-tuning config - confirm that the classification task expects this.
  masking_ratio: 0.7  # MAE-style fraction of masked patches
  importance_ratio: 0.6  # weight of spectral energy in mask scoring
  use_masking: true  # enable masking

  # --- 5) Rotary position encoding ------------------------------------------
  max_seq_len: 2048  # maximum sequence length handled by RoPE



# Classification head placed on top of the backbone (mlp_classification_head).
model_head:
  hidden_layers:  # widths of the hidden layers, in order
    - 512
    - 256
  num_classes: 5  # five-class PTB-XL task
  drop: 0.0  # dropout disabled in the head


# Fine-tuning task options (classification_task).
task:
  freeze_backbone: false  # backbone weights are not frozen
  layerwise_lr_decay: 0.9  # presumably a per-layer LR multiplier - confirm
  noise_level: 0.0  # zero noise level
  augment_prob: 0.0  # zero augmentation probability
  transform: null  # no transform configured
  
# Optimizer settings for fine-tuning.
optimizer:
  optim: 'AdamW'
  # Written with an explicit decimal point in the mantissa: a bare `5e-4` is
  # resolved as a *string* by YAML 1.1 loaders such as plain PyYAML, whereas
  # `5.0e-4` is a float under every loader (same value, 0.0005).
  lr: 5.0e-4
  betas: [0.9, 0.98]
  weight_decay: 0.01


# Learning-rate schedule settings (the `cosine_lr` scheduler option).
scheduler:
  # Interpolates the whole `trainer` node into the scheduler config; presumably
  # so the scheduler can read values such as max_epochs - confirm.
  trainer: ${trainer}
  warmup_epochs: 5
  # Both rates carry an explicit decimal point in the mantissa: bare `5e-6` /
  # `5e-7` are resolved as *strings* by YAML 1.1 loaders such as plain PyYAML,
  # whereas this spelling is a float under every loader (values unchanged).
  min_lr: 5.0e-6
  warmup_lr_init: 5.0e-7


# Output location; the run's version label is the experiment name.
io:
  base_output_path: '/ECG_Finetune_small/logs_ptbxl'
  version: ${experiment_name}


# Trainer settings (presumably forwarded to the PyTorch Lightning Trainer -
# confirm). This node is also interpolated into `scheduler.trainer`.
trainer:
  accelerator: gpu
  # Taken from the top-level num_nodes value.
  num_nodes: ${num_nodes}
  # NOTE(review): -1 is Lightning's convention for "all available devices" -
  # confirm against the Lightning version in use.
  devices: -1
  max_epochs: 30
  strategy: ddp
  # A value of 1 means an optimizer step after every batch (no accumulation).
  accumulate_grad_batches: 1
  # Run validation after every training epoch.
  check_val_every_n_epoch: 1


