# @package _global_
defaults:
  # This defaults list only contains `_self_`, i.e. the content of this file.
  # Intent (from the original note): use the main verl PPO trainer configuration as base, null out parameters that
  # we do not use, and set sensible defaults.
  # NOTE(review): the base verl config is not listed here, so it is presumably composed by the config that
  # includes this file -- confirm.
  - _self_

# Critic parameters. The PPO critic model is the same as the actor model.
critic:
  # PPO uses a critic model, GRPO does not
  optim:
    lr: 1e-5  # Learning rate for the critic optimizer
  model:
    # Reads the same `model_path` key as actor_rollout_ref.model.path (defined outside this file)
    path: ${model_path}

    # Use same lora config as actor
    lora_alpha: ${actor_rollout_ref.model.lora_alpha}
    lora_rank: ${actor_rollout_ref.model.lora_rank}
    target_modules: ${actor_rollout_ref.model.target_modules}

    # Performance parameters (all read from the `performance` group, which is defined outside this file)
    use_remove_padding: ${performance.use_remove_padding}
    enable_gradient_checkpointing: ${performance.offloading.enable_gradient_checkpointing}
    enable_activation_offload: ${performance.offloading.enable_activation_offload}
    fsdp_config:
      forward_prefetch: ${performance.fsdp_forward_prefetch}

  # Mirrored from the actor / rollout settings via interpolation, so changing them there also changes them here
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  rollout_n: ${actor_rollout_ref.rollout.n}
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}

  # Performance parameters
  strategy: ${performance.strategy}
  use_dynamic_bsz: ${performance.use_dynamic_bsz}
  ppo_micro_batch_size_per_gpu: ${performance.base_micro_bsz}
  forward_micro_batch_size_per_gpu: ${performance.micro_batch_size.forward}
  # Both token limits below read the same `critic_max_token_len_per_gpu` value
  forward_max_token_len_per_gpu: ${performance.dynamic_batch_size.critic_max_token_len_per_gpu}
  ppo_max_token_len_per_gpu: ${performance.dynamic_batch_size.critic_max_token_len_per_gpu}

trainer:
  logger:
    - console
    - wandb
  val_before_train: false
  n_gpus_per_node: ${n_gpus}
  nnodes: 1

  # Project name for experiment tracking (e.g., wandb)
  project_name: troll_iclr26
  # Placeholder: replace with your wandb entity
  entity: 'TODO!!'

  # Experiment name for run identification in tracking tools
  experiment_name: ${exp_name}

  # Save and load configuration
  # Start each run from scratch unless specified otherwise.
  resume_mode: disable
  # Set resume_mode to "resume_from_path" to load from the checkpoint at trainer.resume_from_path.
  # NOTE(review): upstream verl names this mode "resume_path" -- confirm which value this code base accepts.
  resume_from_path: ${trainer.default_local_dir}
  # Maximum number of checkpoints to keep. Only keep the latest checkpoint in each run.
  max_ckpt_to_keep: 1
  max_actor_ckpt_to_keep: ${trainer.max_ckpt_to_keep}
  max_critic_ckpt_to_keep: ${trainer.max_ckpt_to_keep}
  # A new run starting from this checkpoint will *not* delete the old one, even if it saves to the same dir.
  # Save each run to its (hopefully unique) exp name
  default_local_dir: ${base_logging_dir}/checkpoints/${exp_name}
  # Should be None for save&load
  default_hdfs_dir: null

  # Mandatory values (`???`): these have no default in this file and must be set elsewhere
  total_epochs: ???
  save_freq: ???
  test_freq: ???
  log_val_generations: ???

algorithm:
  # Only used if use_pf_ppo is True
  use_pf_ppo: false
  pf_ppo: null
  #   weight_pow: ~
  #   reweight_method: ~

  # Only used if use_kl_in_reward is True
  use_kl_in_reward: ???  # Mandatory value: must be set elsewhere
  kl_ctrl:
    type: fixed
    kl_coef: 0.001  # KL coefficient for the reward, not the loss
  kl_penalty: null

  # Advantage estimation
  # Only used for GRPO and DrGRPO (NOTE(review): presumably this refers to norm_adv_by_std_in_grpo -- confirm)
  adv_estimator: null  # Options: "gae", "grpo"
  norm_adv_by_std_in_grpo: true

  filter_groups:
    enable: false
    max_num_gen_batches: 10


reward_model:
  # Disabled in this file, and no model is configured
  enable: false
  model: null

  # Performance parameters (read from the `performance` group, which is defined outside this file)
  strategy: ${performance.strategy}
  use_dynamic_bsz: ${performance.use_dynamic_bsz}
  micro_batch_size_per_gpu: ${performance.base_micro_bsz}
  forward_max_token_len_per_gpu: ${performance.dynamic_batch_size.forward_max_token_len_per_gpu}

actor_rollout_ref:
  batch_config:
    # Micro batch sizes, read from the `performance` group (defined outside this file)
    base_micro_batch_size: ${performance.base_micro_bsz}
    rollout_micro_batch_size: ${performance.micro_batch_size.rollout}


  actor:
    policy_loss:
      # The layer type (sparse/dense) is determined automatically from loss_mode:
      # - trpl_dense uses DtrplLayer (dense)
      # - trpl, trpl_debug, trpl_seq use SdtrplLayer (sparse)
      loss_mode: 'vanilla'  # use "trpl" for TRPL
      clip_cov_lb: null
      clip_cov_ub: null
      ppo_kl_coef: null
      kl_cov_ratio: null
      clip_cov_ratio: null
    sparsify_logits:
      use: true
      # If true, always keep the logits of the selected token (otherwise, it may be zeroed out)
      keep_selected_token: true
      # Value that sparsified-out entries are set to
      default: 1e-12
      # Probability threshold for sparsification. Everything below this threshold (except for the selected token)
      # will be sparsified out and set to `default`
      threshold: 1e-6
      # Keep enough of the largest logits for at least 1 - threshold of all probability
      # (e.g. 99.9% for a threshold of 1e-3)
      total_default_mass: true
      # Limit for the number of kept logits, even if this drops more than total_default_mass would
      total_default_keep_maxnum: 256
      chunk_size: -1  # don't chunk
    trpl:
      kl_bound: 0.05  # bound between old and new policy
      alpha: 1.0  # weight of the kl distillation loss
      # Only implemented for the "trpl_seq" loss.
      # If true, will project the full sequence (of shape seq_len*#logits) to that of the reference policy.
      # This essentially smoothens the bound over the sequence. Usually needs a higher kl bound.
      project_full_sequence: false
      opt_config:
        num_points: 8
        max_steps: 20
        x_threshold: 1e-5
        lower: 1e-7
        upper: 1e2
    optim:
      lr: 1e-6  # Learning rate for the actor
      weight_decay: 0.0  # Weight decay for the actor optimizer
      warmup_style: constant
    loss_agg_mode: token-mean  # token-mean, seq-mean-token-sum-norm, seq-mean-token-sum
    grad_clip: 1  # Gradient clipping value
    clip_ratio: 0.2  # PPO clip ratio
    # Lower and upper clip ratios both default to the symmetric clip_ratio above
    clip_ratio_low: ${actor_rollout_ref.actor.clip_ratio}
    clip_ratio_high: ${actor_rollout_ref.actor.clip_ratio}
    # Clips the loss with -clip_ratio_c*advantages. Seems to take effect very sparsely
    clip_ratio_c: 10
    ppo_epochs: 1  # How many ppo steps to do on each train_batch sample
    ppo_mini_batch_size: 32

    # Performance parameters (read from the `performance` group, which is defined outside this file)
    ppo_micro_batch_size_per_gpu: ${performance.base_micro_bsz}
    strategy: ${performance.strategy}
    use_dynamic_bsz: ${performance.use_dynamic_bsz}
    ppo_max_token_len_per_gpu: ${performance.dynamic_batch_size.ppo_max_token_len_per_gpu}
    fsdp_config:
      forward_prefetch: ${performance.fsdp_forward_prefetch}
      # Whether to offload parameters to CPU when not in use. Saves GPU memory, but slower.
      offload_policy: false

    # for PPO, no KL loss
    use_kl_loss: ???  # Mandatory value: must be set elsewhere
    kl_loss_coef: 0.001
    kl_loss_type: null

  model:
    # `model_path` is defined outside this file
    path: ${model_path}

    # LORA subconfig: LoRA is not used (alpha and rank are 0, no target modules)
    lora_alpha: 0
    lora_rank: 0
    target_modules: null

    # Performance parameters (read from the `performance` group, which is defined outside this file)
    use_remove_padding: ${performance.use_remove_padding}
    enable_gradient_checkpointing: ${performance.offloading.enable_gradient_checkpointing}
    enable_activation_offload: ${performance.offloading.enable_activation_offload}

  ref:
    # Reuse the actor's whole sparsify_logits sub-config via a single interpolation.
    # Disabled field-by-field alternative, kept for reference:
    sparsify_logits: ${actor_rollout_ref.actor.sparsify_logits}
      #use: ${actor_rollout_ref.actor.sparsify_logits.use}
      #threshold: ${actor_rollout_ref.actor.sparsify_logits.threshold}
      #default: ${actor_rollout_ref.actor.sparsify_logits.default}
    # Performance parameters (read from the `performance` group, which is defined outside this file)
    log_prob_micro_batch_size_per_gpu: ${performance.base_micro_bsz}
    strategy: ${performance.strategy}
    log_prob_use_dynamic_bsz: ${performance.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${performance.dynamic_batch_size.log_prob_max_token_len_per_gpu}
    fsdp_config:
      forward_prefetch: ${performance.fsdp_forward_prefetch}

  rollout:
    # Also read by critic.rollout_n via interpolation
    n: 8

    name: vllm

    # Performance parameters (read from the `performance` group, which is defined outside this file)
    log_prob_micro_batch_size_per_gpu: ${performance.micro_batch_size.rollout}
    gpu_memory_utilization: ${performance.gpu_memory_utilization}
    tensor_model_parallel_size: ${performance.tensor_model_parallel_size}
    max_num_batched_tokens: ${performance.max_num_batched_tokens}
    log_prob_use_dynamic_bsz: ${performance.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${performance.dynamic_batch_size.log_prob_max_token_len_per_gpu}
    do_sample: true  # default but explicit here!

    # We don't do multi-turn
    multi_turn:
      enable: false

custom_reward_function:
  # Name of the reward function; no file path is configured in this file (`path` is null)
  name: compute_score
  path: null
