# Top-level run settings, referenced further down via ${...} interpolation.
# Run mode: forwarded to the model and embedded in checkpoint/log names.
mode: 'encode'
# Sample rate shared by the data module and the model.
sample_rate: 48000
# Directory for probe checkpoints (??? = mandatory, must be supplied at launch).
probe_ckpt_dir: ???
# Random seed; presumably consumed by the launch script — confirm.
seed: 666
# Codec under evaluation: forwarded to the model and embedded in names.
codec_name: 'xcodec'

# Keyword arguments for pytorch_lightning.Trainer (built via Hydra `_target_`).
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: gpu
  # Device count/list is mandatory and must be supplied at launch.
  devices: ???
  precision: 32
  max_epochs: 10
  log_every_n_steps: 20
  # Validate once per training epoch (a float is a fraction of an epoch),
  # running at most 5 validation batches each time.
  val_check_interval: 1.0
  limit_val_batches: 5

# Common Voice data module (built via Hydra `_target_`).
data:
  _target_: codec_evaluation.probe.dataset.Common_Voice_dataset.Common_voice_dataset.Common_voice_module
  # Tied to the run-wide sample_rate; presumably the rate audio is resampled to.
  target_samplerate: ${sample_rate}
  # NOTE(review): placeholder path; replace with the real audio root. How it
  # combines with the per-split directories is up to the data module — confirm.
  base_audio_dir: /root/path/for/audio

  # Training split.
  train_audio_dir: ???
  train_batch_size: 4
  train_num_workers: 4

  # Validation split.
  val_audio_dir: ???
  val_batch_size: 4
  val_num_workers: 1

  # Test split.
  test_audio_dir: ???
  test_batch_size: 4
  test_num_workers: 1

# Lightning module for the CTC probe (built via Hydra `_target_`).
model:
  _target_: codec_evaluation.probe.model.ctc_lit_prober.CodecCTCProbe
  codec_name: ${codec_name}
  sample_rate: ${sample_rate}
  mode: ${mode}
  # Tokenizer loaded with Speech2TextProcessor.from_pretrained; the
  # checkpoint name/path is mandatory and must be supplied at launch.
  tokenizer:
    _target_: transformers.Speech2TextProcessor.from_pretrained
    pretrained_model_name_or_path: ???
  # `_partial_: true` makes Hydra return a functools.partial; the module
  # presumably supplies any remaining constructor arguments.
  probe_model_builder:
    _target_: codec_evaluation.probe.model.ctc_model.Ctc_Probe
    _partial_: true
    # NOTE(review): presumably must match the tokenizer's vocabulary size —
    # confirm against the tokenizer checkpoint.
    vocab_size: 10000
    codec_vocab_size: 1024
    dropout: 0.1
    lm_head_nums: 8
    conformer_depth: 3
    conformer_heads: 8
  model_ckpt_dir: ???
  # NOTE(review): placeholder path; replace with the real teacher checkpoint.
  teacher_ckpt_path: /codec_ckpt/path/for/xcodec/hubert_base_general_audio

  # Partially applied AdamW (the module presumably passes the parameters).
  # Exponents are written as 1.0e-N because plain YAML 1.1 loaders such as
  # PyYAML require a decimal point for floats and would read `1e-4` as the
  # string "1e-4"; `1.0e-4` is a float under every loader.
  optimizer_builder:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-4
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.08

  # LambdaLR whose multiplier is a partially applied cosine-with-warmup lambda.
  # NOTE(review): step counts are hard-coded; num_training_steps should match
  # the real number of scheduler steps (depends on dataset size, max_epochs,
  # and whether the scheduler is stepped per batch or per epoch) — confirm.
  lr_scheduler_builder:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: codec_evaluation.utils.schedule.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 200
      num_training_steps: 4000
      final_lr_ratio: 0.2

# Lightning callbacks, each built via Hydra `_target_`.
callbacks:
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step

  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  model_summary:
    _target_: pytorch_lightning.callbacks.ModelSummary
    max_depth: 1

  # Keep only the single checkpoint with the lowest val_loss. Note that with
  # trainer.limit_val_batches set, val_loss covers only part of the val set.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val_loss
    dirpath: ${probe_ckpt_dir}
    every_n_epochs: 1
    mode: min
    save_top_k: 1
    # ${codec_name} and ${mode} are absolute OmegaConf interpolations from the
    # config root, so ${mode} is the run mode, not this callback's `mode: min`.
    # {epoch} and {val_loss:.4f} are filled in by ModelCheckpoint at save time.
    # Quoted because the value contains YAML-significant `{`, `}` and `:`.
    filename: '${codec_name}_${mode}_{epoch}-{val_loss:.4f}'
    verbose: true

# TensorBoard logger (built via Hydra `_target_`).
tensorboard:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  # Log root directory; mandatory, must be supplied at launch.
  save_dir: ???
  # Experiment name, e.g. xcodec_encode with the top-level defaults.
  name: '${codec_name}_${mode}'
  # NOTE(review): Lightning only logs the graph when the LightningModule
  # defines `example_input_array` — confirm CodecCTCProbe sets it.
  log_graph: true