# pytorch_lightning==1.8.6
# Fixed seed for the run.
# NOTE(review): presumably consumed by the CLI's seed_everything option to make
# runs reproducible — confirm against the launcher script.
seed_everything: 4444

# Data module, given as a class_path plus init_args. train_params and
# val_params carry the per-split settings; both use the same sampling_rate
# (24000) that the model section uses for sample_rate.
data:
  class_path: rfwave.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: wav_filelist.train
      sampling_rate: 24000
      # 32512 = 127 * 256 (256 is the hop_length set in the model section).
      num_samples: 32512
      batch_size: 64
      num_workers: 8
      # Canonical lowercase boolean, consistent with the rest of the file.
      cache: true

    val_params:
      filelist_path: wav_filelist.valid
      sampling_rate: 24000
      # 65280 = 255 * 256, i.e. longer excerpts than the training split.
      num_samples: 65280
      batch_size: 16
      num_workers: 4
      cache: true

# Experiment module and its sub-components (feature extractor, backbone, head,
# quantizer), each given as a class_path plus init_args.
model:
  class_path: rfwave.experiment_reflow_subband_vq.VocosExp
  init_args:
    # Same value (24000) as sampling_rate in the data section and as the
    # feature extractor's sample_rate below.
    sample_rate: 24000
    feature_loss: false
    wave: true
    # Same value as backbone.init_args.num_bands below.
    num_bands: 8
    # NOTE(review): presumably classifier-free-guidance settings (sampling
    # scale and condition-drop probability) — confirm in VocosExp.
    guidance_scale: 1.0
    p_uncond: 0.1
    # Written with an explicit decimal point: YAML 1.1 loaders such as PyYAML
    # resolve a bare `2e-4` as a string instead of a float.
    initial_learning_rate: 2.0e-4
    num_warmup_steps: 20_000  # Optimizers warmup steps

    feature_extractor:
      class_path: rfwave.feature_extractors.MelSpectrogramFeatures
      init_args:
        sample_rate: 24000
        # n_fft, hop_length and padding are repeated with the same values in
        # the head section below.
        n_fft: 1024
        hop_length: 256
        # Same value (100) as backbone input_channels and quantizer feat_dim.
        n_mels: 100
        padding: center

    backbone:
      class_path: rfwave.models.VocosRFBackbone
      init_args:
        input_channels: 100
        output_channels: 160
        # Same width (512) as head.dim and quantizer.dim.
        dim: 512
        intermediate_dim: 1536
        num_layers: 8
        num_bands: 8
        prev_cond: false
        encodec_num_embeddings: null

    head:
      class_path: rfwave.heads.RFSTFTHead
      init_args:
        dim: 512
        n_fft: 1024
        hop_length: 256
        padding: center

    quantizer:
      class_path: rfwave.quantizer.Quantizer
      init_args:
        feat_dim: 100
        dim: 512
        reduce: 2
        num_layers: 12
        # NOTE(review): five level entries versus num_quantizers: 4 — the
        # relation between the two is not visible here; confirm in Quantizer.
        levels: [4, 4, 4, 4, 4]
        num_quantizers: 4

# Trainer settings, grouped as: hardware selection, step/validation schedule,
# logger, then callbacks (callback order is unchanged).
trainer:
  accelerator: gpu
  devices:
    - 0
  # NOTE(review): the header comment of this file pins pytorch_lightning==1.8.6;
  # "auto" is the Trainer's default strategy in Lightning 2.x — confirm the
  # targeted release accepts it.
  strategy: auto

  # Lightning calculates max_steps across all optimizer steps (rather than
  # number of batches).
  # NOTE(review): this note previously read "1M steps per generator and 1M per
  # discriminator"; no discriminator is configured in this file and the budget
  # is 1M in total — confirm how the steps are actually divided.
  max_steps: 1_000_000
  check_val_every_n_epoch: 10
  # Validation is capped at 10 batches because evaluating all the metrics is
  # time-consuming; raise this for a fuller evaluation.
  limit_val_batches: 10
  log_every_n_steps: 1000

  logger:
    class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      project: rfwave
      name: rfwave
      save_dir: logs-rfwave/

  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        # Keeps the 3 best checkpoints ranked by the monitored val_loss, plus
        # the most recent one.
        monitor: val_loss
        save_top_k: 3
        save_last: true
        # Quoted so the braces and colon of the template stay a plain string.
        filename: 'rfwave_checkpoint_{epoch}_{step}_{val_loss:.4f}'
    - class_path: rfwave.helpers.GradNormCallback
