---
# Global random seed for the run (LightningCLI `seed_everything` option).
seed_everything: 4444

# Dataset / dataloader settings. sampling_rate presumably must match
# model.init_args.sample_rate (both are 24000 in this file).
data:
  class_path: rfwave.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: wav_filelist.train
      sampling_rate: 24000
      # 48000 samples at 24 kHz = 2 s (assuming num_samples is the crop length)
      num_samples: 48000
      batch_size: 32
      num_workers: 8
      cache: true

    val_params:
      filelist_path: wav_filelist.valid
      sampling_rate: 24000
      # 72000 samples at 24 kHz = 3 s (same assumption as for training)
      num_samples: 72000
      batch_size: 16
      num_workers: 4
      cache: true

model:
  class_path: rfwave.experiment_reflow_subband.VocosEncodecExp
  init_args:
    sample_rate: 24000  # presumably must match data.*.sampling_rate
    feature_loss: false
    wave: true
    num_bands: 8
    guidance_scale: 1.0
    p_uncond: 0.1
    # Written with a decimal point so YAML 1.1 resolvers (e.g. PyYAML's
    # default) read it as a float; a bare `2e-4` only matches YAML 1.2's
    # float pattern and is a string under 1.1.
    initial_learning_rate: 2.0e-4
    num_warmup_steps: 20_000  # Optimizer warmup steps

    feature_extractor:
      class_path: rfwave.feature_extractors.EncodecFeatures
      init_args:
        encodec_model: encodec_24khz
        bandwidths: [1.5, 3.0, 6.0, 12.0]
        train_codebooks: false

    backbone:
      class_path: rfwave.models.VocosRFBackbone
      init_args:
        input_channels: 128
        output_channels: 192
        dim: 512
        intermediate_dim: 1536
        num_layers: 8
        # Kept equal to model.init_args.num_bands (presumably required — confirm).
        num_bands: 8
        prev_cond: false
        # Presumably one embedding per entry in feature_extractor bandwidths
        # (4 entries) — confirm.
        encodec_num_embeddings: 4

    head:
      class_path: rfwave.heads.RFSTFTHead
      init_args:
        dim: 512  # presumably must equal backbone dim
        n_fft: 1280
        # 320-sample hop at 24 kHz = 75 frames/s.
        hop_length: 320
        padding: same

trainer:
  check_val_every_n_epoch: 10
  logger:
    class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      project: rfwave
      save_dir: logs-rfwave-encodec
      name: rfwave-encodec
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        # Quoted so the `{...}` format placeholders stay an unambiguous string.
        filename: 'rfwave_checkpoint_{epoch}_{step}_{val_loss:.4f}'
        save_top_k: 3
        save_last: true
    - class_path: rfwave.helpers.GradNormCallback

  # Lightning counts max_steps over optimizer steps, not batches. No
  # discriminator is configured here, so if the experiment steps a single
  # optimizer this is 1M updates; if it steps several optimizers per batch,
  # the per-optimizer count is correspondingly lower — confirm.
  max_steps: 1_000_000
  # You may want to limit validation batches when computing all metrics,
  # since they are time-consuming.
  limit_val_batches: 10
  accelerator: gpu
  devices: [0]
  strategy: auto
  log_every_n_steps: 1000
