# Hydra defaults list: each entry composes another config into this one.
defaults:
  # Root-level config `basic` (presumably basic.yaml beside this file — confirm).
  - basic
  # Config group `model`, option `flow_matching_base`; by default placed
  # under the `model` key.
  - model: flow_matching_base
  # `group@package: option` places the selected config at `package`:
  # data/train_val -> data_dict (read below via ${data_dict.*_data_list}).
  - data@data_dict: train_val
  # loss/weighted_sum -> loss_fn.
  - loss@loss_fn: weighted_sum
  # data/task_sampling/time_align_balanced -> task_sampling_weights
  # (consumed by the train sampler via ${task_sampling_weights}).
  - data/task_sampling@task_sampling_weights: time_align_balanced
  # _self_ listed last: values in this file override those composed above.
  - _self_

# Experiment identity and output locations.
exp_name: 'wave_dit_fm'
# Output root for this run (note the trailing slash).
exp_dir: 'experiments/${exp_name}/'
# Because exp_dir already ends in '/', this resolves to
# experiments/wave_dit_fm//train.log; the doubled slash is harmless on POSIX.
logging_file: '${exp_dir}/train.log'

# Training loader. The sampling order comes from TaskIteratingSampler, which
# receives the per-task weights selected in the defaults list. `shuffle` is
# passed to the sampler rather than the DataLoader, since torch's DataLoader
# does not accept shuffle=True together with a custom sampler.
train_dataloader:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: data_module.dataset.TaskGroupedAudioGenConcatDataset
    datasets: ${data_dict.train_data_list}
  sampler:
    _target_: data_module.sampler.TaskIteratingSampler
    shuffle: true
    task_sampling_weights: ${task_sampling_weights}
  batch_size: 24  # per device batch size
  num_workers: 12
  collate_fn:
    _target_: data_module.collate_function.PaddingCollate
    pad_keys:
      - waveform
      - duration
      - instruction
    torchify_keys:
      - is_time_aligned

# Validation loader: a plain concatenation of the validation datasets, read in
# a fixed order (shuffle: false). Uses the same collate settings as the
# training loader.
val_dataloader:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: data_module.dataset.AudioGenConcatDataset
    datasets: ${data_dict.val_data_list}
  batch_size: 24  # per device batch size
  shuffle: false
  num_workers: 12
  collate_fn:
    _target_: data_module.collate_function.PaddingCollate
    pad_keys:
      - waveform
      - duration
      - instruction
    torchify_keys:
      - is_time_aligned

# Learning-rate warmup settings. Only warmup_steps is set; warmup_epochs is
# left null (presumably the trainer then uses the step count — confirm).
warmup_params:
  warmup_steps: 10000
  warmup_epochs: null
  epoch_length: ${epoch_length}

gradient_accumulation_steps: 1  # forwarded to trainer.gradient_accumulation_steps

# AdamW optimizer. Written as 5.0e-5 (with a decimal point) so that YAML 1.1
# loaders such as PyYAML resolve it as a float without an explicit !!float tag.
optimizer:
  _target_: torch.optim.AdamW
  lr: 5.0e-5
  weight_decay: 0.01

# LR schedule built through transformers.get_scheduler. The optimizer and
# step counts are not set here; presumably the trainer supplies them (confirm).
lr_scheduler:
  _target_: transformers.get_scheduler
  name: constant_with_warmup
  # alternative schedule: name: linear

epochs: 150  # forwarded to trainer.epochs
epoch_length: 2000  # forwarded to trainer.epoch_length and warmup_params.epoch_length

# Trainer construction. Run-wide settings are interpolated from top-level keys
# so each is defined in one place.
trainer:
  _target_: audio_generation_trainer.MultiTaskAudioGenerationTrainer
  project_dir: ${exp_dir}
  logging_file: ${logging_file}
  gradient_accumulation_steps: ${gradient_accumulation_steps}
  max_grad_norm: 1.0  # presumably the gradient-clipping threshold — confirm
  epochs: ${epochs}
  epoch_length: ${epoch_length}
  save_last_k: 2  # presumably how many recent checkpoints are kept — confirm
  permanent_save_every_n_steps: 100000
  early_stop: null  # presumably disables early stopping — confirm in trainer
  logging_config:
    _target_: trainer.LoggingConfig
    report_to: swanlab
    project: uniflow_audio_experiments
    # project: runs
    save_dir: ${exp_dir}
    name: ${exp_name}
    resume_id: null
    workspace: anonymous
  metric_monitor:
    _target_: trainer.MetricMonitor
    metric_name: loss
    mode: min  # presumably "lower is better" for the monitored metric