#!/usr/bin/python
'''
TODO: import the package used to submit GPU training jobs.
'''

# --- Locations (placeholder values; replace before launching) ---------------
PATH_TO_VERL = "Path/To/Your/VeRL/Directory"  # Path to verl repository
data_root = "Path/To/Your/Data/Directory"     # parent of the checkpoint tree
TRAIN_DATA = "Path/To/Your/Original/Train/Data"  # passed as data.train_files
EVAL_DATA = "Path/To/Your/EvalData"              # passed as data.val_files
MODEL_PATH = "Path/To/Qwen3-1.7B"                # passed as actor_rollout_ref.model.path

# --- Run identity -------------------------------------------------------------
PROJECT_NAME = "DeepScaleR-AIME24"
EXPERIMENT_NAME = "qwen3-1.7b-DeepScaleR-AIME24-baseline-dapo"
TASK_NAME = "-".join((PROJECT_NAME, EXPERIMENT_NAME))
IMAGE_URL = "VeRL Image"  # presumably the container image for the job -- confirm

# Checkpoints live under <data_root>/experiment_check/<project>/<experiment>.
CHECKPOINTS_DIR = "/".join(
    (data_root, "experiment_check", PROJECT_NAME, EXPERIMENT_NAME)
)

import pathlib

# Create the checkpoint directory up front, at the time this script runs.
pathlib.Path(CHECKPOINTS_DIR).mkdir(
    exist_ok=True,  # re-running the script must not fail on an existing dir
    parents=True,   # also create any missing parent directories
)

# Cluster shape.
N_nodes = 1
gpu_nun_per_node = 8  # spelling kept as-is: other parts of the script read this name
total_gpus = N_nodes * gpu_nun_per_node

# TRAIN ===========================
# Batch sizes scale with the node count.
mini_bsz = N_nodes * 16     # passed as actor_rollout_ref.actor.ppo_mini_batch_size
train_bsz = mini_bsz * 1    # passed as data.train_batch_size

bsz_per_gpu = 4     # batch per gpu
rl_ = 8             # response len (k)

prompt_len = 1024 * 1
response_len = 1024 * rl_
# Token budget per GPU = bsz_per_gpu sequences of maximum length. Derived from
# the actual prompt/response lengths (rather than a hard-coded "1k prompt")
# so that changing prompt_len keeps the budget consistent.
token_len_per_gpu = bsz_per_gpu * (prompt_len + response_len)
# One maximum-length sequence (prompt + response).
max_num_batched_tokens = prompt_len + response_len

# rollout n---------
roll_n, val_n = 16, 8                   # rollout.n / rollout.val_kwargs.n
temperature, temperature_val = 1.2, 0.6  # rollout / val_kwargs temperature
# ------------------

# parallel----------
tp_size = 1     # rollout.tensor_model_parallel_size
sp_size = 1     # actor.ulysses_sequence_parallel_size
offload = True  # actor param/optimizer offload and ref param offload
# ------------------


# reward -----------
use_kl_in_reward = False
kl_coef = 0.001  # algorithm.kl_ctrl.kl_coef
# ------------------


# loss-------------
use_kl_loss = True
kl_loss_coef = 0.001
kl_loss_type = 'low_var_kl'
# Alternatives left by the author (not active):
#   kl_loss_type = 'ada_coef_kl'
#   loss_type = 'neg_zero'
# NOTE(review): loss_agg_mode is defined here but is not passed in the
# training command visible in this file -- confirm whether that is intended.
loss_agg_mode = "token-mean"
# -----------------

# -----------------

# DAPO ------------
adv_estimator = 'grpo'  # algorithm.adv_estimator

# Group filtering (algorithm.filter_groups.*).
enable_filter_groups = True
filter_groups_metric = 'acc'
max_num_gen_batches = -1    # algorithm.filter_groups.max_num_gen_batches
gen_prompt_bsz = train_bsz  # data.gen_batch_size

# Clip range (actor.clip_ratio_low / actor.clip_ratio_high).
clip_ratio_low, clip_ratio_high = 0.2, 0.28

# Overlong-response buffer (reward_model.overlong_buffer.*).
enable_overlong_buffer = False
overlong_buffer_len = 0
overlong_penalty_factor = 0.0
# ----------------
# TRAIN ===========================

# Resource request: 8 cores and 96 * 1024 memory units per GPU, capped at
# 128 cores and 1536 * 1024 memory units.
# NOTE(review): memory is presumably in MiB -- confirm against the scheduler.
cpu_cores = min(gpu_nun_per_node * 8, 128)
memory = min(gpu_nun_per_node * 96 * 1024, 1536 * 1024)

'''
Training Cluster Settings
'''

# Setup commands: change into the verl checkout, install extra Python
# packages, then install verl itself in editable mode.
# The trailing backslashes are consumed by Python (backslash-newline inside a
# non-raw string literal), so the cd/pip commands end up on ONE line joined by
# `&&`; only the first line of the string stays on a line of its own.
# NOTE(review): `[[Emvironment Setup Command]]` looks like a placeholder that
# must be replaced before use -- confirm. pre_cmd is also not referenced by
# `command` in the visible part of this file; presumably the job-submission
# step consumes it separately.
pre_cmd = f"""
[[Emvironment Setup Command]]
cd {PATH_TO_VERL} &&\
pip install rouge jieba math-verify[antlr4_9_3] &&\
pip install -e . 
"""

# Template for the trainer invocation: runs the `recipe.dapo.main_dapo` module
# with `key=value` overrides filled in from the settings defined above.
# Each trailing backslash is consumed by Python (backslash-newline inside a
# non-raw string literal), so the literal is already a single logical line;
# the remaining indentation is collapsed to single spaces further down, where
# training_command is re-assigned.
# `$@` is left in the string literally; any expansion happens in whatever
# shell eventually runs the command.
training_command = f"""
python3 -m recipe.dapo.main_dapo \
    data.train_files={TRAIN_DATA} \
    data.val_files={EVAL_DATA} \
    data.train_batch_size={train_bsz} \
    data.gen_batch_size={gen_prompt_bsz} \
    data.max_prompt_length={prompt_len} \
    data.max_response_length={response_len} \
    data.filter_overlong_prompts=True \
    data.shuffle=False \
    algorithm.adv_estimator={adv_estimator} \
    algorithm.use_kl_in_reward={use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef={kl_coef} \
    algorithm.filter_groups.enable={enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches={max_num_gen_batches} \
    algorithm.filter_groups.metric={filter_groups_metric} \
    actor_rollout_ref.model.path={MODEL_PATH} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size={mini_bsz} \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu={token_len_per_gpu} \
    actor_rollout_ref.actor.clip_ratio_low={clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high={clip_ratio_high} \
    actor_rollout_ref.actor.use_kl_loss={use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef={kl_loss_coef} \
    actor_rollout_ref.actor.kl_loss_type={kl_loss_type} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size={sp_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload={offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload={offload} \
    +actor_rollout_ref.rollout.enable_prefix_caching=False \
    actor_rollout_ref.rollout.max_num_batched_tokens={max_num_batched_tokens} \
    actor_rollout_ref.rollout.tensor_model_parallel_size={tp_size} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.temperature={temperature} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.n={roll_n} \
    actor_rollout_ref.rollout.val_kwargs.n={val_n} \
    actor_rollout_ref.rollout.val_kwargs.temperature={temperature_val} \
    actor_rollout_ref.ref.fsdp_config.param_offload={offload} \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable={enable_overlong_buffer} \
    reward_model.overlong_buffer.len={overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor={overlong_penalty_factor} \
    trainer.val_before_train=True \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name={PROJECT_NAME} \
    trainer.experiment_name={EXPERIMENT_NAME} \
    trainer.default_local_dir={CHECKPOINTS_DIR} \
    trainer.n_gpus_per_node={gpu_nun_per_node} \
    trainer.max_actor_ckpt_to_keep=1 \
    trainer.nnodes={N_nodes} \
    trainer.save_freq=5 \
    trainer.test_freq=5 \
    trainer.total_epochs=10 $@ 
"""

# Environment variables exported (in this order) before the trainer starts.
# Values are emitted unquoted, so `$PYTHONPATH` is left for the executing
# shell to expand.
env_vars = {
    "RAY_memory_monitor_refresh_ms": "0",
    "RAY_memory_usage_threshold": "0.99",
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
    "PYTHONPATH": f"{PATH_TO_VERL}:$PYTHONPATH",
    "WANDB_MODE": "offline",
    "WANDB_DIR": f"/Path/To/Your/Wandb/{PROJECT_NAME}/{EXPERIMENT_NAME}",
}

cd_cmd = f"cd {PATH_TO_VERL}"
env_cmd = " && ".join(
    f"export {name}={value}" for name, value in env_vars.items()
)
# Collapse the multi-line template into one line with single spaces between
# tokens; str.split() with no argument already drops leading/trailing
# whitespace and never yields empty tokens.
training_command = " ".join(training_command.split())
# Full command line: list installed packages, enter the repo, export the
# environment, then launch training -- each step gated on the previous one.
command = " && ".join(("pip list", cd_cmd, env_cmd, training_command))


'''
TODO: add the job-submission command.
'''