# lightning.pytorch==2.2.1
# true lets LightningCLI choose a seed itself (the chosen value is written
# into the config the CLI saves for the run).
seed_everything: true
trainer:
  # Hardware and distribution.
  accelerator: auto
  strategy: ddp
  devices: 2
  num_nodes: 1
  precision: bf16-mixed
  # Experiment tracking.
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: phalar
  callbacks:
    # Keep the five checkpoints with the highest valid_accuracy.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: valid_accuracy
        save_top_k: 5
        mode: max
        # Quoted because the template contains braces and a colon.
        filename: 'checkpoint-{epoch}-{valid_accuracy:.4f}'
  # Run length; null / -1 keep Lightning's defaults (no extra limit).
  fast_dev_run: false
  max_epochs: 500
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  # Per-stage batch limits; null falls back to Lightning's default.
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  # Validate every epoch; the pre-training sanity validation pass is disabled.
  val_check_interval: null
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 0
  # Logging / UI toggles.
  log_every_n_steps: null
  enable_checkpointing: null
  enable_progress_bar: true
  enable_model_summary: null
  # Optimization.
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  # Reproducibility and runtime behaviour.
  deterministic: true
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
  # Placeholder output location; set per machine.
  default_root_dir: /path/to/checkpoints
model:
  # Optimizer hyper-parameters. Exponent values are written with a dotted
  # mantissa (4.0e-3, not 4e-3). YAML 1.1 loaders such as PyYAML's safe_load
  # would otherwise resolve them as strings instead of floats.
  learning_rate_muon: 0.02
  learning_rate: 4.0e-3
  weight_decay: 1.0e-3
  proj_weight_decay: 1.0e-3
  embedding_dim: 512
  embedding_mode: RANDOM
  # Alternative: DOUBLE_CHANNEL_HARMONIC_PERCUSSIVE
  input_type: SINGLE_CHANNEL_SPECTROGRAM
  dropout_p: 0.1
  # Options: BILINEAR_SIMILARITY / COCOLA_SIMILARITY
  comparison_method: BILINEAR_SIMILARITY
  # Earlier results per target value: 0.9 -> 0.168, 0.8 -> 0.171,
  # 0.7 -> 0.163 (metric not recorded here -- TODO: note which metric).
  label_smoothing_targ: 0.9
  backbone_type: freqtimeseparable
  ft_sep_init_hidden_dim: 8
  ft_sep_depth: 5
  pool_type: spec_pool
  spec_pool_out_channels: 80
  spec_pool_components: 8
  spec_pool_norm: global
  spec_pool_complex_output: false
  spec_pool_center_padding: true
  # Set according to the spectrogram type:
  # 513 for STFT | 128 for MEL | 64 for HPSS | 96 for CQT
  input_freq_bins: 96
  do_cqt: true
data:
  # Placeholder dataset location; set per machine.
  root_dir: /path/to/datasets
  dataset: MIXED
  batch_size: 64
  # Lower bound is ~2 so that FFT pooling gets at least 2 components.
  chunk_duration_range: [2, 10]
  chunk_duration_test: 5
  target_sample_rate: 16000
  generate_submixtures: true
  # Options: STFT/CQT/MEL/HPSS_SPECTROGRAM | RAW_WAVEFORM
  feature_extractor_type: RAW_WAVEFORM
  feature_extraction_time: ONLINE
  augmentations:
    # Disabled augmentations are kept below, commented out, for reference.
    # pitch_shift:
    #   p: 0.5
    #   semitones: [-0.5, 0.5]
    # time_stretch:
    #   p: 0.5
    #   rate: [0.8, 1.2]
    random_gain:
      p: 1.0
      db: [-6, 6]
    white_noise:
      p: 0.5
      snr_db: [15, 25]
    pink_noise:
      p: 0.5
      snr_db: [15, 25]
    brown_noise:
      p: 0.3
      snr_db: [12, 20]
    band_limited_noise:
      p: 0.3
      snr_db: [12, 20]
      # NOTE(review): 8000 equals Nyquist for target_sample_rate 16000. A
      # filter design needing a cutoff strictly below Nyquist would reject
      # it -- confirm against the augmentation's implementation.
      band: [100, 8000]
    transient_noise:
      p: 0.2
      num_bursts: [1, 4]
      burst_length: [10, 31]
      snr_db: [6, 12]
    # random_tone:
    #   p: 0.3
    #   freq_range: [5, 7800]
    #   snr_db: [10, 20]
    # freq_shift:
    #   p: 0.5
    #   bins: [-6, 7]
    gaussian_noise:
      p: 0.5
      std: [0.005, 0.05]
# Checkpoint passed to the CLI subcommand (e.g. fit) to resume from;
# null means no checkpoint is loaded.
ckpt_path: null
