# Top-level settings. Most of them are reused further down in this file
# through ${...} interpolation. A value of `???` is a placeholder for a
# mandatory setting that has to be supplied before the config is used.

# Codec under test and the representation being probed. Both are passed to
# the model section and are used to name checkpoints and logger runs.
codec_name: speechtokenizer  # needs to be changed (per codec)
mode: quantized_emb

# Audio settings, forwarded to data.dataset_args and to the model section.
# Presumably the sampling rate in Hz and the clip length in seconds.
sample_rate: 16000
target_sec: 10

# Probe task and output size, forwarded to the model section.
# num_outputs is presumably the number of classes -- confirm for the dataset.
task: multiclass
num_outputs: 6

# Checkpoint directory used as `dirpath` by callbacks.model_checkpoint.
probe_ckpt_dir: ???

# Not referenced elsewhere in this file; presumably read by the entry script.
seed: 666
save_result: null

# Keyword arguments for the class named by `_target_`.
trainer:
  _target_: pytorch_lightning.Trainer

  # Hardware and numeric precision. `devices` is mandatory (`???`).
  accelerator: gpu
  devices: ???
  precision: 32

  # Training length.
  max_epochs: 100

  # Validation cadence. Keep 1.0 a float: Lightning reads a float as a
  # fraction of the training epoch and an int as a number of batches.
  val_check_interval: 1.0
  # An int here caps validation at that many batches per run.
  limit_val_batches: 10

  # Logging frequency, in training steps.
  log_every_n_steps: 5

# Data module configuration; `_target_` names the class to build and the
# remaining keys are passed to it.
data:
  _target_: codec_evaluation.probe.dataset.VocalSound_dataset.VocalSound_dataset.VocalSoundDataModule

  # Audio locations for each split; all three are mandatory (`???`).
  train_audio_dir: ???
  val_audio_dir: ???
  test_audio_dir: ???

  # Forwarded from the top-level settings of the same names.
  dataset_args:
    sample_rate: ${sample_rate}
    target_sec: ${target_sec}

  # Per-split batch size and worker count.
  train_batch_size: 32
  train_num_workers: 8
  val_batch_size: 2
  val_num_workers: 4
  test_batch_size: 32
  test_num_workers: 4


# Prober configuration. `_target_` names the class to build; entries marked
# `_partial_: true` are turned into partially-applied callables by Hydra's
# instantiate, so the target can complete them with its own arguments later.
model:
  _target_: codec_evaluation.probe.model.lit_prober.Prober
  # These values are interpolated from the top-level settings.
  codec_name: ${codec_name}
  sample_rate: ${sample_rate}
  mode: ${mode}
  task: ${task}
  num_outputs: ${num_outputs}
  # Builder for the probe head.
  # NOTE(review): padding/kernel_size/stride look like convolution settings;
  # confirm their meaning in MulticlassProber.
  probe_model_builder:
    _target_: codec_evaluation.probe.model.multiclass_model.MulticlassProber
    _partial_: true
    num_outputs: ${num_outputs}
    drop_out: 0.1
    channel_reduction: 16
    padding: 1
    kernel_size: 3
    stride: 1
  target_sec: ${target_sec}
  # Mandatory (`???`). Presumably the location of the codec checkpoint.
  model_ckpt_dir: ???

  # Optimizer factory. Floats are written with an explicit decimal point
  # (1.0e-4 rather than 1e-4) so that plain YAML 1.1 loaders also read them
  # as numbers instead of strings.
  optimizer_builder:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-4
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.08

  # Scheduler factory: LambdaLR driven by a warmup + cosine lr_lambda.
  lr_scheduler_builder:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: codec_evaluation.utils.schedule.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10
      # NOTE(review): fixed value, not derived from trainer.max_epochs or the
      # dataset size; confirm it matches the real number of optimizer steps.
      num_training_steps: 10000
      final_lr_ratio: 0.2

# Callbacks, keyed by an arbitrary name; each entry is built from its
# `_target_` class with the remaining keys as arguments.
callbacks:
  # Logs the learning rate at every step.
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: step

  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  model_summary:
    _target_: pytorch_lightning.callbacks.ModelSummary
    max_depth: 1

  # Keeps only the single best checkpoint, judged by the lowest `val_loss`.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val_loss
    dirpath: ${probe_ckpt_dir}
    every_n_epochs: 1
    # Comparison direction for `monitor` (lower is better).
    mode: min
    save_top_k: 1
    save_last: false
    # `${codec_name}` and `${mode}` are absolute interpolations, so `${mode}`
    # is the top-level `mode`, not the `mode: min` key of this mapping.
    # `{epoch}` and `{val_loss:.4f}` are left for ModelCheckpoint to fill in.
    filename: '${codec_name}_${mode}_{epoch}-{val_loss:.4f}'
    verbose: true

# Logger configuration; built from the class named by `_target_`.
tensorboard:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  # Run name, composed from the top-level `codec_name` and `mode`.
  name: ${codec_name}_${mode}
  # Mandatory (`???`): root directory for the log files.
  save_dir: ???
  # NOTE(review): per the Lightning docs, graph logging needs the model to
  # define `example_input_array`; confirm that Prober sets it.
  log_graph: true