              
                                                      
                                         

from itertools import chain
import argparse
import dataclasses
import json
import os
import sys
import time

from megatron.training.arguments import _print_args
from megatron.core.transformer.enums import AttnBackend
from megatron_datasets.args import _add_dataset_extra_args


def gpatch_extra_args(parser):
    """Register gpatch's extra training arguments on *parser*.

    Adds a ``'train extra args'`` argument group (tokenizer, data-loading,
    padding, model-architecture and MoE options), then chains the dataset,
    checkpoint-conversion, RL, EMA, LoRA, multi-modal, vLLM and monitoring
    argument helpers, each of which receives the parser and returns it.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The parser produced by the final helper in the chain.
    """
    group = parser.add_argument_group(title='train extra args')

    # --- tokenizer ---
    group.add_argument("--px-use-fast-tokenizer",
                       action='store_true',
                       help="whether to use fast tokenizer or not")
    group.add_argument("--tokenizer-seed", type=int, default=1111, help="")

    # --- data loading and evaluation ---
    group.add_argument("--px-do-eval-per-domain",
                       action='store_true',
                       help="whether to do eval per domain or not")
    group.add_argument("--px-clear-train-data-consuming-progresses",
                       action='store_true',
                       help="whether to clear train data consuming progresses when restoring from checkpoint")
    group.add_argument("--px-reset-dataloader-at-start-of-eval",
                       action='store_true',
                       help="reset dataloader at start of eval")
    group.add_argument("--px-dataloader-prefetch-factor",
                       type=int,
                       default=None,
                       help="dataloader prefetch factor")
    group.add_argument("--px-task-name", type=str, default=None, help="heartbeat task name")

    # --- dynamic-length padding ---
    group.add_argument("--px-inputs-pad-to-longest",
                       action='store_true',
                       help="allow training with dynamic length")
    group.add_argument("--px-pad-to-multiple-of",
                       type=int,
                       default=2048,
                       help="the multiple of longest padding")
    group.add_argument('--px-smart-padding-buffer-size',
                       type=int,
                       default=2048,
                       help='unordered-examples-buffer-size if px-inputs-pad-to-longest enable')

    # --- model architecture ---
    # The order of `choices` is kept stable because argparse shows it in
    # --help output and in invalid-choice error messages.
    group.add_argument('--model-arch',
                       type=str,
                       default=None,
                       choices=[
                           'welm_19b',
                           'llama',
                           'bog',
                           'yi_9b',
                           'qwen2-72b',
                           'baichuan-7b',
                           'dsr1-distill-qwen2.5-32b',
                           'qwen2.5-1.5b',
                           'bog-moe',
                           'qwen2.5-math-rm-72b',
                           'qwen2.5-math-1.5b',
                           'welm_moe_32b',
                           'qwen1.5-moe',
                           'qwen2vl',
                           'qwen2.5vl',
                           'qwq-32b',
                           'gemma3',
                           'qwen3',
                           'qwen3-moe',
                       ],
                       help='model arch type')
    group.add_argument("--adj-rope-w-pos-ids",
                       action='store_true',
                       help="adjust rope with position ids")

    # --- MoE ---
    group.add_argument("--moe-norm-topk-prob",
                       action='store_true',
                       help="whether to norm probs when using moe act as softmax")
    group.add_argument("--moe-norm-topk-prob-eps",
                       type=float,
                       default=0.0,
                       help="the eps in the denominator when --moe-norm-topk-prob is set")
    group.add_argument("--moe-pad-with-random-token",
                       action='store_true',
                       help="whether to pad with random token or not")

    # --- vocabulary ---
    # In NVIDIA Megatron-Core the padded vocab size is normally derived from
    # the user's vocab size and the tp-size. Qwen checkpoints instead use a
    # fixed, larger value that is not directly tied to the tokenizer (possibly
    # because a larger tp-size was used upstream), so it can be set explicitly.
    # The help text below states the same thing in Chinese.
    group.add_argument("--padded-vocab-size",
                       type=int,
                       default=None,
                       help='''Nvidia M-core 里这个东西是用户 vocab size 加上 tp-size 算出来的，不过 qwen 的
                       是固定的较大值，而不是与 tokenizer 直接相关的，可能是 ali pai 内部用了一个较大的 tp-size
                       之类的。
                       ''')

    # --- misc ---
    group.add_argument("--num-gpus-per-node", type=int, default=8)
    group.add_argument("--context-parallel-heads-kv-stride",
                       type=int,
                       default=None,
                       help='llama3 context parallel heads_k_stride')
    group.add_argument("--image-token-id", type=int, default=None, help='image token id')
    group.add_argument('--cli-arg-yaml-cfgs',
                       type=str,
                       default=None,
                       nargs='*',
                       help='yaml config files, each storing part of cli args')

    # Chain the remaining argument groups; every helper returns the parser.
    parser = _add_dataset_extra_args(parser)
    parser = _add_ckpt_conv_args(parser)
    parser = _add_rl_args(parser)
    parser = _add_ema_args(parser)
    parser = _add_lora_args(parser)
    parser = _add_multi_modal_args(parser)
    parser = _add_vllm_args(parser)
    parser = _add_monitor_args(parser)

    return parser


def _add_ckpt_conv_args(parser):
    group = parser.add_argument_group(title='px ckpt conv args')
    group.add_argument('--load-model-provider',
                       type=str,
                       default=None,
                       help='module that defines model_provider')
    group.add_argument('--save-model-provider',
                       type=str,
                       default=None,
                       help='module that defines model_provider')
    group.add_argument('--load-rm-provider',
                       type=str,
                       default=None,
                       help='module that defines model_provider')
    return parser


def _add_monitor_args(parser):
    group = parser.add_argument_group(title='args for monitoring')
    group.add_argument('--do-monitor', action='store_true', help='the world size of the monitor server')
    group.add_argument('--monitor-server-ip',
                       type=str,
                       default='127.0.0.1',
                       help='the ip of the monitor server')
    group.add_argument('--monitor-port',
                       type=int,
                       default=62000,
                       help='the port of the monitor server')
    group.add_argument(
        '--monitor-interval',
        type=int,
        default=30,
        help='the interval in seconds for the client to ping the server',
    )
    group.add_argument(
        '--monitor-max-time-wo-progress',
        type=int,
        default=3600,
        help='if no progress is detected for too long time, the worker will be taken down',
    )
    group.add_argument('--auto-set-finetune-arg',
                       action='store_true',
                       help='turn on to allow auto reseting args such as checkpoint dir')
    return parser


def _add_rl_args(parser):
    group = parser.add_argument_group(title='rl args')
         
    group.add_argument("--dpo", action='store_true', help="dpo")
    group.add_argument('--dpo-beta', type=float, default=0.1, help='dpo beta')
    group.add_argument('--dpo-label-smoothing', type=float, default=0., help='dpo label smoothing')
    group.add_argument('--dpo-ftx-gamma', type=float, default=0., help='dpo ftx gamma')
    group.add_argument('--dpo-pair-gamma', type=float, default=1., help='dpo pair gamma')
    group.add_argument('--dpo-reward-models-cnt',
                       type=int,
                       default=0,
                       help='the reward models count')
    group.add_argument('--dpo-margin-keys',
                       type=str,
                       nargs='*',
                       default=[],
                       help='the margin keys, is a list')
    group.add_argument('--dpo-model-using',
                       type=str,
                       default='both',
                       choices=['policy', 'ref', 'both'],
                       help='1. policy: only using the policy model, support for train and infer'
                       '2. ref: only using the reference model, only support for infer'
                       '3. both: using the policy and reference model, support for train and infer')
    group.add_argument("--dpo-golden-loss", action='store_true', help="using the dpo golden loss")
    group.add_argument("--dpo-policy-ref-model-cnt",
                       type=int,
                       default=1,
                       help="the moel cnt when using dpo")

                             
    group.add_argument('--dpo-gen-margin-key',
                       type=str,
                       default=None,
                       help='the gen margin key, for modpo gen margin')
    group.add_argument('--dpo-gen-margin-path',
                       type=str,
                       default=None,
                       help='the gen margin path, for modpo gen margin')
                       
    group.add_argument('--dpo-reward-model-paths',
                       type=str,
                       nargs='*',
                       default=[],
                       help='the reward model path list in mpdpo, only using for convert models'
                       'if empty that means the dpo')
    group.add_argument('--dpo-loss-of-orion', action='store_true', default=False)
    group.add_argument('--orpo-loss', action='store_true', default=False)
    group.add_argument('--dpo-golden-margin', type=float, default=0.5, help='dpo-golden-margin')
        
    group.add_argument('--rm-sentence',
                       action='store_true',
                       help="use sentence RM (equivalent to value model)")
    group.add_argument('--rm-use-avg-pool', action='store_true', help="RM use avg pool")
    group.add_argument('--rm-use-focal-loss', action='store_true', help="RM use focal loss")
    group.add_argument('--rm-focal-loss-lambda',
                       type=float,
                       default=0.02,
                       help='RM focal loss coeff')
    group.add_argument('--rm-focal-loss-gamma', type=float, default=2., help='RM focal loss gamma')
    group.add_argument('--rm-focal-loss-range', type=float, default=5., help='RM focal loss range')
    group.add_argument('--rm-use-triplet-loss', action='store_true', help="RM use triplet loss")
    group.add_argument('--rm-golden-margin', type=float, default=0.5, help='rm golden margin')
    group.add_argument('--rm-triplet-coef', type=float, default=0.5, help='rm triplet coef')
    group.add_argument(
        '--rm-triplet-focal-coef',
        type=float,
        nargs='*',
        default=None,
        help='coef of rm triplet focal loss. follow by [loss_cr_coef, loss_gc_coef, loss_gr_coef]')
    group.add_argument('--rm-focal-loss-coef',
                       type=float,
                       nargs='*',
                       default=None,
                       help=f'coef of rm focal loss. follow by '
                       f'[rm-focal-loss-lambda, rm-focal-loss-gamma, rm-focal-loss-range]')
    group.add_argument('--rm-focal-loss-ranking-coef',
                       type=float,
                       nargs='*',
                       default=None,
                       help=f"the ranking loss coef of rm focal loss, follow by "
                       f" [ranking_cr_coef, ranking_gc_coef, ranking_gr_coef]")
    group.add_argument('--rm-tokenizer-models',
                       nargs='*',
                       default=None,
                       help='RM model 的 sentencepiece tokenizer models')
    group.add_argument('--actor-tokenizer-model',
                       type=str,
                       default=None,
                       help='Sentencepiece tokenizer model of actor')

         
    group.add_argument('--rl-role',
                       type=str,
                       default=None,
                       help='one of actor, sampler, rm, rm-and-critic')
    group.add_argument('--infer-engine-impl',
                       type=str,
                       default="vllm",
                       choices=['vllm', 'sglang', 'mcore'],
                       help="sampler infer engine backend")
    group.add_argument('--sampler-dist-init-addrs', type=str, nargs='*', default=[])

                 
    group.add_argument(
        '--ppo-rollout-micro-batch-size',
        type=int,
        default=1,                                    
        help='ppo rollout micro batch size')
    group.add_argument(
        '--ppo-rollout-global-batch-size',                                                     
        type=int,
        default=512,
        help='ppo rollout global batch size')
    group.add_argument('--ppo-resp-seq-len', type=int, default=512, help='ppo resp seq len')
    group.add_argument('--ppo-rollout-batch-seq-length',
                       type=int,
                       default=None,
                       help='ppo rollout batch seq length')
    group.add_argument('--ppo-rollout-pad-to-multiple-of',
                       type=int,
                       default=None,
                       help='ppo rollout pad to multiple of')
    group.add_argument('--ppo-rollout-debug-pad-to-prompt-seq-len',
                       action='store_true',
                       help='ppo rollout debug pad to prompt seq len')
    group.add_argument('--ppo-calc-adv-per-token-rewards-factor',
                       type=float,
                       default=1.0,
                       help='''
    一般的 PPO 没有 per token rewards，lucky 老板增加了这个算法改动。在计算 advantage 的时候，
    per token rewards x 这个 factor，然后 + 到 advantage 上。
    ''')

                        
    group.add_argument('--ppo-smart-pad-infer',
                       action='store_true',
                       help='enable ppo smart_pad for infer')
    group.add_argument('--ppo-smart-pad-train',
                       action='store_true',
                       help='enable ppo smart-pad for train')
    group.add_argument('--ppo-train-dynamic-mbs-target-seq',
                       type=int,
                       default=None,
                       help='ppo train dynamic mbs target seq')
    group.add_argument('--ppo-train-dynamic-mbs-limit',
                       type=int,
                       default=None,
                       help='ppo train dynamic mbs limit')
    group.add_argument('--ppo-pack-seq',
                       action='store_true',
                       help='enable pack seq for ppo train')
    group.add_argument('--ppo-wecube-report',
                       action='store_true',
                       help='whether to report ppo metrics to wecube')

                       
    group.add_argument('--ppo-critic-data-parallel-size',
                       type=int,
                       default=1,
                       help='ppo critic 的数据并行度')
    group.add_argument('--ppo-critic-tensor-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo critic 的 tensor 并行度')
    group.add_argument('--ppo-critic-pipeline-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo critic 的 pipeline 并行度')
    group.add_argument('--ppo-critic-node-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='''ppo critic 的 node ip，每个 critic pod 提供一个 ip。
    例如两个 gemini pod，ip 分别是 1.1.1.2 和 1.1.1.3，那么 `--ppo-critic-ips 1.1.1.2 1.1.1.3`。
    ''')
    group.add_argument('--ppo-critic-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='ppo critic 的 ip，每个 critic 提供一个 ip。自动生成，不用填。')
    group.add_argument('--ppo-critic-ports',
                       type=int,
                       nargs='*',
                       default=[63000],
                       help='''
    PPO critic port，每个 worker 一个 port。

    如果只有一个 port，会自动 +1 +1 +1。比如 dp-size=4，`--ppo-critic-ports 31000`，
    那么实际会分别占用 31000，31001，31002，3003 这 4 个 port。这么设计是防止 pod 内 端口冲突。

    如果用户提供 `--ppo-critic-ports 31001 31002 10086 1234`， 那么选用户的入参。
    ''')
    group.add_argument(
        '--ppo-rm-critic-server-timeout-keep-alive',
        type=int,
        default=5,                                       
        help='PPO server timeout ka')
    group.add_argument('--ppo-rm-critic-client-timeout',
                       type=int,
                       default=60,
                       help='PPO client timeout')
    group.add_argument("--combine-rm-and-critic-server",
                       action='store_true',
                       help="ppo combine rm and critic server")

                        
    group.add_argument("--ppo-standalone-sampler",
                       action='store_true',
                       help="use standalone sampling server")
    group.add_argument('--ppo-sampler-data-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的数据并行度')
    group.add_argument('--ppo-sampler-tensor-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的 tensor 并行度')
    group.add_argument('--ppo-sampler-pipeline-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的 pipeline 并行度')
    group.add_argument('--ppo-sampler-node-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='''ppo sampler 的 node ip，每个 sampler pod 提供一个 ip。
    例如两个 gemini pod，ip 分别是 1.1.1.2 和 1.1.1.3，那么 `--ppo-sampler-ips 1.1.1.2 1.1.1.3`。
    ''')
    group.add_argument('--ppo-sampler-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='ppo sampler 的 ip，每个 sampler 提供一个 ip。自动生成，不用填。')
    group.add_argument('--ppo-sampler-ports',
                       type=int,
                       nargs='*',
                       default=[64000],
                       help='''
    PPO sampler port，每个 worker 一个 port。

    如果只有一个 port，会自动 +1 +1 +1。比如 dp-size=4，`--ppo-sampler-ports 31000`，
    那么实际会分别占用 31000，31001，31002，3003 这 4 个 port。这么设计是防止 pod 内 端口冲突。

    如果用户提供 `--ppo-sampler-ports 31001 31002 10086 1234`， 那么选用户的入参。
    ''')
    group.add_argument('--ppo-sampler-server-timeout-keep-alive',
                       type=int,
                       default=5,
                       help='PPO server timeout ka')
    group.add_argument('--ppo-sampler-client-timeout',
                       type=int,
                       default=60,
                       help='PPO client timeout')
    group.add_argument('--ppo-sampler-client-update-timeout',
                       type=int,
                       default=60,
                       help='PPO client update timeout')
    group.add_argument('--ppo-step-update-sampler-interval',
                       type=int,
                       default=1,
                       help='interval (ppo step) to update sampler')
    group.add_argument('--ppo-sampler-save',
                       type=str,
                       default='ppo-sampler-save',
                       help='''
    ppo actor 更新 sampler 的权重中转目录。

    注意：必须是 DFS（否则 sampler 无法访问）。
    默认在 `${PWD}/ppo-sampler-save/${TMP}`，如果你在 cephfs 执行训练，那么并不需要关心这个参数。
    ''')
    group.add_argument('--ppo-step-update-sampler-delay',
                       type=int,
                       default=0,
                       help='''
    sampler 到达 ppo-step-update-sampler-interval 之后，会进行更新，这里加上一个 delay。sampler 会稍后等一会，
    再 generate 一波，这样 actor 在导出之后，立刻就能取得一波 samples。然后 delay 个 ppo step 后再更新权重。
    提高吞吐，但会有一点 weights 不 match，我觉得小问题。
    ''')

                      
    group.add_argument('--ppo-actor-node-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='''ppo actor 的 node ip，每个 actor pod 提供一个 ip。
    例如两个 gemini pod，ip 分别是 1.1.1.2 和 1.1.1.3，那么 `--ppo-actor-node-ips 1.1.1.2 1.1.1.3`。
    ''')
    group.add_argument('--ppo-actor-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='ppo actor 的 ip，每个 actor 提供一个 ip。自动生成，不用填。')
    group.add_argument('--ppo-actor-ports', type=int, nargs='*', default=[65000], help="")

               
    group.add_argument('--ppo-max-epochs', type=int, default=1, help='ppo max epochs')
    group.add_argument('--ppo-max-epochs-2', type=int, default=1, help='ppo max epochs 2')
    group.add_argument('--ppo-step-per-epoch', type=int, default=None, help='ppo iter per epoch')
    group.add_argument('--train-iters-each-rollout',
                       type=int,
                       default=-1,
                       help='train iters each rollout')
    group.add_argument('--ppo-initial-policy-kl-penalty',
                       type=float,
                       default=0.01,
                       help='ppo compute init policy kl')
    group.add_argument("--ppo-use-absolute-kl", action='store_true', help='ppo use abs kl')
    group.add_argument('--ppo-discount-factor',
                       type=float,
                       default=1.0,
                       help='ppo discount factor, ppo GAE gamma γ')
    group.add_argument('--ppo-gae-lambda', type=float, default=0.95, help='ppo gae lambda λ')
    group.add_argument('--ppo-normalize-advantages', action='store_true', help='ppo normalize adv')
    group.add_argument('--ppo-entropy-bonus',
                       type=float,
                       default=0.0,
                       help='ppo actor entropy bonus')
    group.add_argument('--ppo-ratio-eps', type=float, default=0.2, help='ppo actor ratio eps')
    group.add_argument(
        '--ppo-dual-clip-ratio-c',
        type=float,
        default=None,
        help='Dual-clip PPO, should be greater than 1.0, detail in: https://arxiv.org/pdf/1912.09729'
    )
    group.add_argument('--ppo-clamp-kl-val',
                       type=float,
                       default=None,
                       help='clamp kl in (-x, +x) for numerical stability')
    group.add_argument('--ppo-logps-ratio-clamp',
                       type=float,
                       default=None,
                       help='clamp logps ratio in (-x, +x) for numerical stability')
    group.add_argument('--ppo-rollout-temperature',
                       type=float,
                       default=1.0,
                       help='ppo actor rollout temperature')
    group.add_argument('--ppo-rollout-repetition-penalty',
                       type=float,
                       default=1.0,
                       help='ppo actor rollout repetition penalty')
    group.add_argument('--ppo-rollout-top-k', type=int, default=1, help='ppo actor rollout top k')
    group.add_argument('--ppo-rollout-top-p',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p')
    group.add_argument('--ppo-rollout-top-p-decay',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p decay')
    group.add_argument('--ppo-rollout-top-p-bound',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p bound')
    group.add_argument('--ppo-to-offload-adam-states',
                       action='store_true',
                       help='ppo to offload adam states')
                
    group.add_argument(
        '--ppo-loss-clip-val',                                                       
        type=float,
        default=0.2,
        help='ppo loss clip val')
    group.add_argument('--ppo-enable-standardization',
                       action='store_true',
                       help='reward_standardization_enable')
    group.add_argument('--ppo-reward-mean', type=float, default=0., help='ppo reward mean')
    group.add_argument('--ppo-reward-var', type=float, default=1., help='ppo reward var')
    group.add_argument('--ppo-per-token-reward-mean',
                       type=float,
                       default=0.,
                       help='ppo per token reward mean')
    group.add_argument('--ppo-per-token-reward-var',
                       type=float,
                       default=1.,
                       help='ppo per token reward var')
    group.add_argument('--ppo-reward-clip-val',
                       type=float,
                       default=1.,
                       help='ppo reward clip range (-x, +x)')
    group.add_argument('--ppo-reward-count', type=int, default=0, help='ppo reward count')
    group.add_argument('--ppo-per-token-reward-count',
                       type=int,
                       default=0,
                       help='ppo per token reward count')
    group.add_argument('--ppo-reward-scalings',
                       type=float,
                       nargs='*',
                       default=[],
                       help='scales ppo reward by this factor')
    group.add_argument('--ppo-reward-len-penalty-coef',
                       type=float,
                       default=0.,
                       help='ppo reward len penalty coef')
    group.add_argument('--ppo-reward-len-penalty-mean',
                       type=float,
                       default=256,
                       help='ppo reward len penalty mean')
    group.add_argument('--ppo-reward-len-penalty-std',
                       type=float,
                       default=128,
                       help='ppo reward len penalty std')
    group.add_argument('--ppo-value-truncate-head',
                       action='store_true',
                       help='''
    将 value 的 head truncate 掉，而不是 truncate 掉 last token（eos token）。
    nemo 和 trl 都是这么 truncate 的，但分析起来非常诡异，所以这里加了一个开关：
    1. values truncate 掉第 first token，而不是 first token
    2. critic server 在 loss 计算的时候，truncate 掉 first token，而不是 last token。
    ''')
         
    group.add_argument('--ppo-disable-tqdm', action='store_true', help='disable tdqm')
    group.add_argument('--ppo-display-rollout-generation',
                       action='store_true',
                       help='ppo display rollout generation')
    group.add_argument('--load-ref',
                       type=str,
                       nargs='*',
                       default=[],
                       help='ppo ref model path (ref / reward), is a list')
    group.add_argument('--save-ref',
                       type=str,
                       nargs='*',
                       default=[],
                       help='a path to store ppo ref model updated, it works only on ppo ref')
    group.add_argument('--rm-ref-factors',
                       type=float,
                       nargs='*',
                       default=[],
                       help='ppo ref model path (ref / reward) factor, is a list')
    group.add_argument('--ppo-auto-calc-args',
                       action='store_true',
                       help='PPO automatically calculate args')
    group.add_argument('--ppo-step-save-interval',
                       type=int,
                       default=-1,
                       help='ppo save global step interval')
    group.add_argument('--ppo-actor-freeze-ppo-steps',
                       type=int,
                       default=0,
                       help='ppo actor no update steps')
    group.add_argument('--freeze-iters', type=int, default=0, help='ppo actor freeze iteration')
    group.add_argument('--gen-left-pad', action='store_true', help='gen left pad')
    group.add_argument('--ppo-logps-fwd-micro-batch-size',
                       type=int,
                       default=1,
                       help='ppo logps fwd micro batch size')
    group.add_argument('--attn-unpad-kv-cache', action='store_true', help='attn unpad kv cache')
    group.add_argument('--ppo-debug-gen', action='store_true', help='ppo debug gen')
    group.add_argument('--ppo-debug-fake-gen', action='store_true', help='ppo debug fake gen')
    group.add_argument('--ppo-debug-fake-rm-critic',
                       action='store_true',
                       help='ppo debug fake rm critic')
    group.add_argument('--ppo-sort-prompts-across-batches',
                       type=int,
                       default=1,
                       help='sort prompts across batches for higher generation throughput')
    group.add_argument('--ppo-rollout-max-prompt-len-diff',
                       type=int,
                       default=256,
                       help='ppo actor prompt bucket len interval')
    group.add_argument('--ppo-rm-mask-prompt',
                       action='store_true',
                       help='ppo actor Reward Model mask the prompt when compute the loss')
    group.add_argument('--gen-term-at-nan',
                       action='store_true',
                       help='terminate if nan in softmax of logits')
    group.add_argument('--ppo-sampling-repeat',
                       type=int,
                       default=1,
                       help='ppo prompt 重复采样次数，如果为 8，会重复采样 8 次')
    group.add_argument('--ppo-sampling-keep',
                       type=int,
                       default=1,
                       help='ppo prompt 重复采样后保留个数，如果为 2，会保留两个结果')

                                                                              
           
        
                                 
        
                 
                                
                                                                                      
                      
    group.add_argument('--ppo-update-ref-w-actor-interval',
                       type=int,
                       default=0,
                       help='constantly update ref with actor')
    group.add_argument('--ppo-update-ref-w-actor-coef',
                       type=float,
                       default=0.,
                       help='update ref with actor coef')
    group.add_argument('--ppo-num-rm', type=int, default=1, help='PPO number of RM')
    group.add_argument('--rm-output-sequence',
                       type=int,
                       nargs='*',
                       default=None,
                       help='''
    一般 rm 只有输出 scalar 一个分数；这个 feat 是因为，
    sentence level 忠实度 rm 的输出不是一个分数，而是 sequence，然后从 seq 里取出最小的一个。
    如果 --rm-output-sequence 0 1 0，表示 rms[0]、rms[2] 不 output sequence，rms[1] output sequence。
    默认关闭。
    ''')
    group.add_argument('--rm-output-scalar',
                       type=int,
                       nargs='*',
                       default=None,
                       help='''
    这个逻辑比较特殊，lucky 将其分成了多种情况：
    1. 没有都 false 的情况；
    2. critic 目前只有 output sequence（即 per token reward）；
    3. reward 默认情况，按照 ppo paper，是 seq 的 reward，也就是 output scalar。
    4. lucky 希望支持两者都为 true，同时输出。

    output sequence 输出 sequence，表示 per token reward，
    output scalar   输出 scalar  ，表示 sequence 的 reward。
    ''')
    group.add_argument('--rm-num-attributes',
                       type=int,
                       default=1,
                       help="the num attributes of reward model")
    group.add_argument('--use-grpo', action='store_true', help='use grpo')
    group.add_argument('--grpo-advantage-epsilon',
                       type=float,
                       default=1e-6,
                       help='-grpo advantage epsilon')
    group.add_argument('--grpo-kl-loss-beta', type=float, default=1e-3, help='grpo kl loss coef')
    group.add_argument('--rm-head-arch',
                       type=str,
                       default='single_layer',
                       choices=['single_layer', 'multi_layers'],
                       help='rm head arch')
    group.add_argument('--ppo-save-first-rollout-data',
                       action='store_true',
                       help='ppo-save-first-rollout-data')
    group.add_argument("--ppo-grpo-reward-type",
                       type=str,
                       default='rm_only',
                       choices=["rm_only", "rule_only", "rm_with_rule"],
                       help="grpo reward type")
    group.add_argument("--ppo-rule-reward-beta",
                       type=float,
                       default=1.0,
                       help="rule reward beta, only use when rm_with_rule")
    group.add_argument("--ppo-rm-reward-alpha",
                       type=float,
                       default=1.0,
                       help="rm reward alpha, only use when rm_with_rule")
    group.add_argument("--grpo-prefetch-samplings",
                       action='store_true',
                       help='GRPO 提前 prefetch samplings')
    group.add_argument("--grpo-rpc-max-retries", type=int, default=3, help="grpo rpc max retries")
    group.add_argument("--ppo-dynamic-sampling", action='store_true', help='dynamic sampling')
    group.add_argument("--ppo-dynamic-sampling-max-replay",
                       type=int,
                       default=2,
                       help="ppo max replay sample times, useful when --ppo-dynamic-sampling")

            
    group.add_argument('--no-use-rm-and-critic',
                       action='store_false',
                       help='Disable tradition BT rm and critic',
                       dest='use_rm_and_critic')
    group.add_argument('--use-gen-rm', action='store_true', help='use gen rm')
    group.add_argument('--ppo-gen-rm-repeat', type=int, default=1, help='gen rm voting times')
    group.add_argument('--ppo-gen-rm-resp-seq-len',
                       type=int,
                       default=None,
                       help='gen-rm resp seq len')
    group.add_argument('--ppo-gen-rm-data-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的数据并行度')
    group.add_argument('--ppo-gen-rm-tensor-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的 tensor 并行度')
    group.add_argument('--ppo-gen-rm-pipeline-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo sampler 的 pipeline 并行度')
    group.add_argument('--ppo-gen-rm-ips',
                       type=str,
                       nargs='*',
                       default=[],
                       help='ppo sampler 的 ip，每个 sampler 提供一个 ip。自动生成，不用填。')
    group.add_argument('--ppo-gen-rm-ports',
                       type=int,
                       nargs='*',
                       default=[64000],
                       help='''
    PPO sampler port，每个 worker 一个 port。

    如果只有一个 port，会自动 +1 +1 +1。比如 dp-size=4，`--ppo-gen-rm-ports 31000`，
    那么实际会分别占用 31000，31001，31002，3003 这 4 个 port。这么设计是防止 pod 内 端口冲突。

    如果用户提供 `--ppo-gen-rm-ports 31001 31002 10086 1234`， 那么选用户的入参。
    ''')
    group.add_argument('--gen-rm-dist-init-addrs', type=str, nargs='*', default=[])
    group.add_argument('--ppo-gen-rm-client-timeout',
                       type=int,
                       default=60,
                       help='PPO gen rm client timeout')
    group.add_argument('--ppo-gen-rm-temperature',
                       type=float,
                       default=1.0,
                       help='ppo actor generation reward temperature')
    group.add_argument('--ppo-gen-rm-repetition-penalty',
                       type=float,
                       default=1.0,
                       help='ppo actor generation reward repetition penalty')
    group.add_argument('--ppo-gen-rm-top-k', type=int, default=1, help='ppo actor rollout top k')
    group.add_argument('--ppo-gen-rm-top-p',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p')
    group.add_argument('--ppo-gen-rm-top-p-decay',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p decay')
    group.add_argument('--ppo-gen-rm-top-p-bound',
                       type=float,
                       default=0.0,
                       help='ppo actor rollout top p bound')
    group.add_argument('--ppo-debug-update-weight',
                       action='store_true',
                       help="whether to debug update or not")
    group.add_argument('--ppo-debug-sglang-sleep-wakeup-generate',
                       action='store_true',
                       help="whether to debug sglang sleep-wakeup-generate or not")
    group.add_argument('--ppo-fmt-factor',
                       type=float,
                       default=0.1,
                       help='ppo format factor, range in (0, 1.0)')
    group.add_argument('--ppo-custom-rule-file',
                       type=str,
                       default="",
                       help='ppo custom rule file')
    group.add_argument('--ppo-early-swap-model',
                       action="store_true",
                       help="whether to swap model before update or not")
    group.add_argument('--replay-sample',
                       action="store_true",
                       help="replay samples with user specified")

               
    group.add_argument(
        "--ppo-step-eval-interval",
        type=int,
        default=-1,
        help="eval interval in ppo_step, <= 0 disables eval",
    )
    group.add_argument(
        "--ppo-eval-steps",
        type=int,
        default=1,
        help=(
            "How many mini-batches to validate in one eval loop. "
            "If set to -1, the entire eval dataset will be used, "
            "and this arg will be auto-calculated according to --ppo-eval-rollout-global-batch-size."
        ),
    )
    group.add_argument(
        "--ppo-eval-rollout-global-batch-size",
        type=int,
        default=256,
        help="global batch size of an eval mini-batch",
    )
    group.add_argument(
        "--ppo-eval-rollout-micro-batch-size",
        type=int,
        default=1,
        help="eval micro-batch size",
    )
    group.add_argument(
        "--ppo-eval-sampling-repeat",
        type=int,
        default=1,
        help="eval sampling repeat times"
    )
    group.add_argument(
        "--ppo-clip-ratio-low",
        type=float,
        default=None,
        help=("dapo clip-higher low ratio. if set, ppo-ratio-eps will be replaced, ref: "
            "https://github.com/volcengine/verl/blob/main/recipe/dapo/README.md#separated-clip-epsilons---clip-higher"
        ),
    )
    group.add_argument(
        "--ppo-clip-ratio-high",
        type=float,
        default=None,
        help=("dapo clip-higher high ratio. if set, ppo-ratio-eps will be replaced, ref: "
            "https://github.com/volcengine/verl/blob/main/recipe/dapo/README.md#separated-clip-epsilons---clip-higher"
        ),
    )
    group.add_argument(
        '--use-gspo-loss',
        action='store_true',
        help='use gspo loss, ref: https://arxiv.org/pdf/2507.18071',
        )
    group.add_argument(
        "--dapo-overlong-penalty",
        action="store_true",
        help=("enable dapo overlong penalty, ref (include following --dapo-overlong-buffer-len & --dapo-overlong-penalty-factor): "
            "https://github.com/volcengine/verl/blob/main/recipe/dapo/README.md#overlong-reward-shaping"
        ),
    )
    group.add_argument(
        "--dapo-overlong-buffer-len",
        type=int,
        default=0,
        help="dapo overlong penalty buffer length",
    )
    group.add_argument(
        "--dapo-overlong-penalty-factor",
        type=float,
        default=0.0,
        help="dapo overlong penalty factor",
    )
    return parser


def _add_ema_args(parser):
    group = parser.add_argument_group(title='ema args')
         
    group.add_argument("--apply-ema", action='store_true', help="whether to apply ema")
    group.add_argument('--ema-beta',
                       type=float,
                       default=0.99,
                       help='ema beta, if eq to 1.0, freeze ema model')
    group.add_argument('--ema-power', type=float, default=0.75, help='')
    group.add_argument('--ema-inv-gamma', type=float, default=1.0, help='')
    group.add_argument('--ema-update-after-step',
                       type=int,
                       default=100,
                       help='start updating the EMA model only after the specified number of steps')
    group.add_argument('--ema-update-interval',
                       type=int,
                       default=10,
                       help='update the EMA model every specified number of steps')
    group.add_argument('--no-load-ema',
                       action='store_true',
                       help='whether to load ema weights or not')
    group.add_argument('--override-ema-params',
                       action='store_true',
                       help='whether to load ema params or not')
    group.add_argument('--reset-step-of-ema',
                       action='store_true',
                       help='whether to reset ema step or not')
    return parser


def _add_lora_args(parser):
    group = parser.add_argument_group(title='LoRA args')
    group.add_argument('--enable-lora', action='store_true', help="enable LoRA")
    group.add_argument('--lora-r', type=int, default=8, help='lora r')
    group.add_argument('--lora-alpha', type=int, default=16, help='lora alpha')
    group.add_argument("--freeze-lora", action='store_true', help="freeze LoRA")

    return parser


def _add_multi_modal_args(parser):
    group = parser.add_argument_group(title='multi modal args')
    group.add_argument("--mm-enable-no-grad-fn",
                       action='store_true',
                       help="some forward can be no grad")
    group.add_argument("--mm-freeze-llm", action='store_true', help="freeze llm")
    group.add_argument("--mm-freeze-vision-encoder",
                       action='store_true',
                       help="freeze vision encode")
    group.add_argument("--mm-freeze-projector", action='store_true', help="freeze projector")
    group.add_argument(
        "--mm-vision-encode-use-local",
        action='store_true',
        help="vision encode using local implement, default is using TransformerEngine")
    group.add_argument("--mm-vision-encoder-model-arch",
                       type=str,
                       default='idefics2-8b-2k',
                       choices=['idefics2-8b-2k', 'idefics2-8b-4k'],
                       help="vision encode type")

    group.add_argument(
        '--mm-vision-encoder-hf-pretrain-path',
        type=str,
        default=None,
        help='if is None, VE using mlm implement, otherwise using hugging face import')
    return parser


def _add_vllm_args(parser):
    group = parser.add_argument_group(title='vLLM Sampler Arguments')
    group.add_argument('--hf-config-json-path',
                       type=str,
                       default=None,
                       help="the path of the config.json file of a huggingface model")
    group.add_argument('--hf-model-embed-size',
                       type=int,
                       default=None,
                       help="the size of the huggingface model")
    group.add_argument('--ppo-actor-pipeline-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo actor 的 pipeline parallel size')
    group.add_argument('--ppo-actor-expert-model-parallel-size',
                       type=int,
                       default=1,
                       help='ppo actor 的 expert parallel size')
    group.add_argument('--ppo-actor-data-parallel-size',
                       type=int,
                       default=1,
                       help='ppo actor 的 data size')
    group.add_argument('--sampler-gpu-memory-utilization',
                       type=float,
                       default=0.7,
                       help='sampler gpu memory utilization')
    group.add_argument('--gen-rm-gpu-memory-utilization',
                       type=float,
                       default=0.7,
                       help='gen rm gpu memory utilization')
    group.add_argument('--infer-engine-enable-expert-parallel',
                       action='store_true',
                       help='enbale infer engine open ep parallel')
    group.add_argument('--update-weight-max-size-mb',
                       type=int,
                       default=1024,
                       help="the max size when update weight per times. MB")
    group.add_argument('--no-fused-kernel',
                       action="store_true",
                       help="whether no fused kernel or not")
    return parser


def validate_rl_args(args, defaults=None):
    """Validate and post-process the RL (reward model / PPO / GRPO) arguments.

    Expands ``rm_focal_loss_coef`` into its three named components. Checks the
    consistency of the reward-model loss options, the PPO sampling counts, the
    GRPO restrictions and the inference-engine / attention-backend constraints.
    Finally prints the argument namespace via ``_print_args``.

    Args:
        args: the parsed argument namespace. It is mutated in place: when
            ``rm_focal_loss_coef`` is given, ``rm_focal_loss_lambda``,
            ``rm_focal_loss_gamma`` and ``rm_focal_loss_range`` are set from it.
        defaults: unused; accepted only so existing call sites keep working.
            (Previously a mutable ``{}`` default.)

    Returns:
        The same ``args`` namespace.

    Raises:
        AssertionError: if any consistency check fails.
    """
    import math

    # ---- reward-model loss configuration ----
    if args.rm_focal_loss_coef is not None:
        assert len(args.rm_focal_loss_coef) == 3, (
            f"rm_focal_loss_coef expects 3 values (lambda, gamma, range), "
            f"got {args.rm_focal_loss_coef=}")
        (args.rm_focal_loss_lambda,
         args.rm_focal_loss_gamma,
         args.rm_focal_loss_range) = args.rm_focal_loss_coef

    if args.rm_use_triplet_loss:
        assert args.rm_triplet_focal_coef is not None and len(args.rm_triplet_focal_coef) == 3, (
            f"triplet loss requires 3 rm_triplet_focal_coef values, got {args.rm_triplet_focal_coef=}")
        # Compare with a tolerance: in binary floating point e.g.
        # 0.6 + 0.3 + 0.1 == 0.9999999999999999, so an exact `== 1.0`
        # would reject valid weights.
        assert math.isclose(sum(args.rm_triplet_focal_coef), 1.0), (
            f"rm_triplet_focal_coef must sum to 1.0, got {args.rm_triplet_focal_coef=}")
        assert args.rm_focal_loss_ranking_coef is not None and len(
            args.rm_focal_loss_ranking_coef) == 3, (
                f"triplet loss requires 3 rm_focal_loss_ranking_coef values, "
                f"got {args.rm_focal_loss_ranking_coef=}")
    if args.rm_use_focal_loss:
        assert args.rm_focal_loss_ranking_coef is not None and len(
            args.rm_focal_loss_ranking_coef) == 1, (
                f"focal loss requires 1 rm_focal_loss_ranking_coef value, "
                f"got {args.rm_focal_loss_ranking_coef=}")
    # Focal loss and triplet loss are mutually exclusive: either one or neither.
    assert not (args.rm_use_focal_loss and args.rm_use_triplet_loss), (
        "rm_use_focal_loss and rm_use_triplet_loss cannot both be enabled")

    # ---- PPO / GRPO sampling ----
    assert args.ppo_sampling_repeat >= args.ppo_sampling_keep, (
        f"ppo_sampling_keep must not exceed ppo_sampling_repeat, "
        f"got {args.ppo_sampling_repeat=} and {args.ppo_sampling_keep=}")
    if args.use_grpo:
        assert not args.ppo_normalize_advantages, (
            "use_grpo is incompatible with ppo_normalize_advantages")
        assert not args.ppo_enable_standardization, (
            "use_grpo is incompatible with ppo_enable_standardization")

    # ---- inference engine ----
    if args.infer_engine_impl == 'sglang':
        assert args.ppo_step_update_sampler_interval == 1, "sglang infer engine requires it to 1"

    # ---- attention backend ----
    # With fused attention and context parallelism, the rollout padding
    # multiple divided by the CP size must be at least 256.
    if args.attention_backend == AttnBackend.fused and args.context_parallel_size > 1:
        assert args.ppo_rollout_pad_to_multiple_of // args.context_parallel_size >= 256, (
            "Using fused attention and context parallel "
            "requires ppo_rollout_pad_to_multiple_of // context_parallel_size >= 256, "
            f"but got {args.ppo_rollout_pad_to_multiple_of=} and {args.context_parallel_size=}"
        )

    _print_args("arguments", args)
    return args
