# Format checks enforced on CI:
# 1. Comments must appear above each field.
# 2. There must be a blank line between each field.
# 3. Inline comments (after a field on the same line) are not allowed.
# 4. Indentation level is respected for nested fields.

# Hydra-style defaults list: the component configs this file is composed from
defaults:

  # dp critic config, inheriting from trainer/config/critic/critic.yaml
  - critic

  # load the base config listed above first, then apply the fields in the current yaml on top of it
  - _self_

# Training strategy for the critic; set to FSDP here (FSDP options live under model.fsdp_config)
strategy: fsdp

# Optimizer settings for the critic
optim:

  # Learning rate. Written as 1.0e-5 rather than 1e-5 so that YAML 1.1 loaders
  # (e.g. PyYAML) also read it as a float instead of the string "1e-5".
  lr: 1.0e-5

  # Minimum LR ratio for cosine schedule; null means no value is set here
  min_lr_ratio: null

  # LR warmup style: "constant" or "cosine"
  warmup_style: constant

# Critic model settings
model:

  # Load the model through shared memory when true
  use_shm: false

  # Move activations to CPU to lower GPU memory usage
  enable_activation_offload: false

  # Enable the remove-padding optimization (saves compute)
  use_remove_padding: false

  # Settings that apply only to FSDP
  fsdp_config:

    # Offload model parameters to CPU when true
    param_offload: false

    # Offload optimizer state to CPU when true
    optimizer_offload: false

    # FSDP2 only: offload param/grad/optimizer during train
    offload_policy: false

    # FSDP2 only: reshard after the forward pass to reduce memory footprint
    reshard_after_forward: true

    # How layers are wrapped with FSDP
    wrap_policy:

      # Minimum number of parameters that triggers wrapping
      min_num_params: 0

    # GPUs per FSDP shard group; -1 selects the size automatically
    fsdp_size: -1

    # FSDP1 only: prefetch the next forward-pass all-gather
    # before the current forward computation.
    forward_prefetch: false

  # LoRA rank; a positive value (e.g., 32) enables LoRA
  lora_rank: 0

  # Scaling factor for LoRA
  lora_alpha: 16

  # Modules LoRA targets: "all-linear" or a list of linear projection layers
  target_modules: all-linear

# Global batch size for forward-only passes during inference;
# interpolated from critic.ppo_micro_batch_size
forward_micro_batch_size: '${critic.ppo_micro_batch_size}'

# Per-GPU batch size for forward-only passes during inference;
# interpolated from critic.ppo_micro_batch_size_per_gpu
forward_micro_batch_size_per_gpu: '${critic.ppo_micro_batch_size_per_gpu}'

# Ulysses-style sequence parallelism size
ulysses_sequence_parallel_size: 1

# Gradient clipping applied to critic updates
grad_clip: 1.0