---
# Global random seed (LightningCLI `seed_everything` option).
seed_everything: 3407

data:
  class_path: decoder.dataset.VocosDataModule
  init_args:
    # Training split. num_samples 72000 at sampling_rate 24000 Hz = 3 s;
    # presumably the per-example clip length - confirm in VocosDataModule.
    train_params:
      filelist_path: /path/to/dir/libritts_merge.txt  # TODO: set real path
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 16  # alternatives noted previously: 40, 20
      num_workers: 8

    # Validation split; same sampling rate and num_samples as training.
    val_params:
      filelist_path: /path/to/dir/audio_files_list.txt  # TODO: set real path
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 5  # alternative noted previously: 10
      num_workers: 8

model:
  class_path: decoder.experiment.WavTokenizer
  init_args:
    sample_rate: 24000
    # Written as 8.0e-5 rather than 8e-5: YAML 1.1 loaders (e.g. PyYAML)
    # only resolve floats containing a '.', so "8e-5" can load as a string.
    initial_learning_rate: 8.0e-5
    mel_loss_coeff: 45
    mrd_loss_coeff: 1.0
    num_warmup_steps: 0  # Optimizers warmup steps
    pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration

    # automatic evaluation
    # NOTE(review): "periodicty" is misspelled but presumably matches the
    # parameter name in decoder.experiment.WavTokenizer - verify before
    # renaming the key.
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_stoi: true
    evaluate_sdr: true
    evaluate_periodicty: true

    # Resume settings. resume is true while both paths below are still
    # placeholders - TODO point them at a real config/checkpoint (or set
    # resume: false) before launching.
    resume: true
    resume_config: path/to/dir/config.yaml
    resume_model: path/to/dir/xxxx.ckpt

    feature_extractor:
      class_path: decoder.feature_extractors.EncodecFeatures
      init_args:
        encodec_model: encodec_24khz
        # NOTE(review): 4 entries here and adanorm_num_embeddings: 4 in the
        # backbone - presumably coupled; confirm and keep in sync.
        bandwidths: [6.6, 6.6, 6.6, 6.6]
        train_codebooks: true
        num_quantizers: 1
        # NOTE(review): "dowmsamples" is misspelled but presumably matches the
        # EncodecFeatures parameter name - verify before renaming the key.
        # Product 6 * 5 * 5 * 3 = 450, equal to head.hop_length below.
        dowmsamples: [6, 5, 5, 3]
        vq_kmeans: 200
        vq_type: hexagon  # alternatives: hexagon, rhombic, rectangle
        codebook_dim: [9, 9, 7, 7, 7, 7]

    backbone:
      class_path: decoder.models.VocosBackbone
      init_args:
        input_channels: 512
        dim: 768
        intermediate_dim: 2304
        num_layers: 12
        adanorm_num_embeddings: 4

    head:
      class_path: decoder.heads.ISTFTHead
      init_args:
        dim: 768  # same value as backbone dim
        n_fft: 1800  # 4 x hop_length
        hop_length: 450  # 24000 Hz / 450 = ~53.3 hops per second of audio
        padding: same

trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: path/to/dir  # TODO: set real log directory
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        # Quoted defensively: the value contains '{', '}' and ':'.
        filename: 'Q2D2_checkpoint_{epoch}_{step}_{val_loss:.4f}'
        save_top_k: 10
        save_last: true
    - class_path: decoder.helpers.GradNormCallback

  # Lightning counts max_steps across all optimizer steps (rather than the
  # number of batches), so this equals 1M steps for the generator plus 1M
  # for the discriminator.
  max_steps: 2000000
  # You might want to limit val batches when evaluating all the metrics,
  # as they are time-consuming.
  limit_val_batches: 100
  limit_train_batches: 1000
  accelerator: gpu
  strategy: ddp
  devices: [0, 1]
  num_nodes: 1
  log_every_n_steps: 1000
