---
args:
  # General arguments
  train_algo: 'SFT'

  # Model arguments
  architecture: 'llama'
  # The !!float tag is required: PyYAML (YAML 1.1) only resolves floats that
  # contain a dot, so a bare `1e5` would otherwise load as the string "1e5".
  rope_theta: !!float 1e5
  partial_rotary_factor: 1.0
  hidden_size: 384
  intermediate_size: 1536
  num_attention_heads: 6
  num_layers: 6
  max_position_embeddings: 1024
  dropout: 0.0

  # Data arguments
  use_iterable_dataset: false
  # 10M examples. Written without `_` digit separators: those are YAML 1.1-only,
  # and a YAML 1.2 core-schema parser would load `10_000_000` as a string.
  num_train: 10000000
  num_eval: 1024
  num_workers: 8
  # NOTE(review): integer rather than boolean; confirm how the loader interprets it.
  add_special_tokens: 0
  padding_side: 'right'
  mask_prompt: false

  # Train data configuration
  train_data:
    - op: 'reverse_add_ICL'
      frac: 1.0
      # NOTE(review): la/lb look like operand-length ranges ([lo, hi]);
      # confirm the semantics against the 'reverse_add_ICL' generator.
      kwargs:
        la: [1, 17]
        lb: [1, 17]

  # Eval data configuration
  eval_data:
    - op: 'reverse_add_ICL'
      frac: 1.0
      # NOTE(review): the three-element `la` is presumably [start, stop, step],
      # and `lb: null` presumably defers lb to the generator's default;
      # confirm both against the 'reverse_add_ICL' generator.
      kwargs:
        la: [5, 35, 4]
        lb: null
      eval_keys: ['la']
      tied_keys: []

training_args:
  # Training arguments
  # NOTE(review): these keys are presumably passed to a transformers
  # TrainingArguments subclass (`dataset_kwargs.skip_prepare_dataset` suggests
  # TRL's SFTConfig); confirm against the loader and the pinned versions.
  resume_from_checkpoint: false
  save_total_limit: 1
  run_name: ""
  output_dir: out
  do_train: true
  do_eval: true
  max_steps: 10000
  learning_rate: !!float 1e-3
  lr_scheduler_type: warmup_stable_decay
  # NOTE(review): warmup (0.1 * 10000 = 1000) + stable (4000) + decay (500)
  # covers only 5500 of the 10000 max_steps. transformers' WSD schedule holds
  # the LR at min_lr_ratio * learning_rate (1e-5) once the decay phase ends,
  # so the final 4500 steps would run at that floor LR. If decay is meant to
  # finish at max_steps, num_stable_steps should be 8500; confirm intent.
  lr_scheduler_kwargs:
    num_stable_steps: 4000
    num_decay_steps: 500
    min_lr_ratio: 0.01
  warmup_ratio: 0.1
  weight_decay: 0.01
  adam_beta2: 0.98
  # Tagged because PyYAML does not resolve a dot-less `1e-12` as a float.
  adam_epsilon: !!float 1e-12
  logging_steps: 50
  eval_strategy: steps
  eval_steps: 500
  remove_unused_columns: false
  eval_on_start: true
  per_device_train_batch_size: 128
  per_device_eval_batch_size: 256
  gradient_accumulation_steps: 4
  # NOTE(review): deprecated in newer transformers releases in favor of
  # `include_for_metrics: ['inputs']`; check against the pinned version.
  include_inputs_for_metrics: true
  save_steps: 500
  torch_compile: false
  bf16: true
  tf32: true
  # NOTE(review): top-level dispatch_batches / split_batches are deprecated in
  # newer transformers releases in favor of `accelerator_config`; check
  # against the pinned version.
  dispatch_batches: false
  split_batches: true
  ignore_data_skip: true
  dataset_kwargs:
    skip_prepare_dataset: true
