################################################################################
# Global variables
################################################################################

variables:
  run_name: 7b-Critique-SFT-merged-2222
  dataset_path: ../datasets/dataset_critic/merged
  micro_batch_size: 2 # Set to what fits your machine. Global batch size must be divisible by gpus * microbatch_size
  global_seed: 17
  max_seq_len: 8192
  max_audio_len: 1500

run_name: ${variables.run_name}

################################################################################
# Model
################################################################################
model:
  feedback_method: csft
  base_model_name_or_path: models/Qwen2-Audio-7B


################################################################################
# Training hparams
################################################################################

seed: ${variables.global_seed}
max_duration: 1ep

lm_weight: 1.0

global_train_batch_size: 16
device_train_microbatch_size: ${variables.micro_batch_size}
device_eval_batch_size: ${variables.micro_batch_size}

# precision: amp_bf16
precision: amp_fp16

optimizer:
  betas:
  - 0.9
  - 0.95
  eps: 1.0e-10
  lr: 5e-6
  name: decoupled_adamw
  weight_decay: 5e-6

scheduler:
  alpha_f: 0.01
  name: cosine_with_warmup
  t_warmup: 0.1dur


################################################################################
# Data
################################################################################

train_loader:
  dataset:
    max_seq_len: ${variables.max_seq_len}
    max_audio_len: ${variables.max_audio_len}
    shuffle: true
    remote: ${variables.dataset_path}
    split: train
  drop_last: true
  name: text
  num_workers: 8

eval_subset_num_batches: -1
eval_loader:
- dataset:
    max_seq_len: ${variables.max_seq_len}
    max_audio_len: ${variables.max_audio_len}
    shuffle: false
    remote: ${variables.dataset_path}
    split: val
  drop_last: true
  label: ultra
  name: text
  num_workers: 8


################################################################################
# Parallelism
################################################################################

fsdp_config:
  verbose: false
  mixed_precision: PURE
  limit_all_gathers: true
  sharding_strategy: FULL_SHARD
  activation_checkpointing: true
  sync_module_states: true

################################################################################
# Logging, callbacks, etc
################################################################################

# Uncomment to log to wandb
# loggers:
  # wandb:
  #   project: 

algorithms:
  gradient_clipping:
    clipping_threshold: 1
    clipping_type: norm

callbacks:
  lr_monitor: {}
  memory_monitor: {}
  speed_monitor:
    window_size: 10
  runtime_estimator: {}

################################################################################
# Misc
################################################################################

autoresume: true
dist_timeout: 14400 
console_log_interval: 50ba
progress_bar: false
python_log_level: debug
log_to_console: true

eval_first: true
eval_interval: 0.1dur

################################################################################
# Model saving / loading
################################################################################
save_folder: ckpts_critic/${variables.run_name}
save_interval: 1dur # Save twice per run