#!/bin/bash
#SBATCH -J leanRL
#SBATCH -p gpu
#SBATCH -N 1
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=20
#SBATCH --mem=200G
#SBATCH --time=15-00:00:00
#SBATCH --constraint="gpu_40g+"
#SBATCH --output=/beegfs/scratch/user/<anonymized>/experiments/lean/logs/%j.log
#SBATCH --error=/beegfs/scratch/user/<anonymized>/experiments/lean/logs/%j.err
#
# Single-node verl RL training (Dr. GRPO) of DeepSeek-Prover-V1.5-SFT with a
# Lean verifier as reward. Submit with `sbatch`; the #SBATCH directives above
# must stay ahead of the first executable command.

# Activate the base miniconda install, then record the job allocation and GPU
# topology in the job log for later debugging.
source ~/miniconda3/bin/activate
scontrol show job "${SLURM_JOB_ID}"
nvidia-smi
nvidia-smi topo -m

source ~/.bashrc
# Fail fast if the training environment cannot be activated, rather than
# launching the trainer with whatever python3 happens to be on PATH.
if ! conda activate verl; then
  echo "error: failed to activate conda environment 'verl'" >&2
  exit 1
fi

# Interconnect tuning (NCCL over InfiniBand, UCX devices). This can make
# training faster, but the right values depend on the cluster hardware.
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
export MLFLOW_TRACKING_URI=http://proxy-carvi.int.europe.naverlabs.com:80/
# Prepend the conda lib dir. ${VAR:+...} avoids leaving a trailing ':' when
# LD_LIBRARY_PATH is unset or empty: an empty path entry makes the dynamic
# loader also search the current working directory.
export LD_LIBRARY_PATH="$HOME/miniconda3/lib/${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
export TENSORBOARD_DIR="/beegfs/scratch/user/<anonymized>/experiments/lean/logs/tensorboard"
MODEL="/beegfs/scratch/user/<anonymized>/fcdpg-verl/DeepSeek-Prover-V1.5-SFT"
# Only referenced by the commented-out fcdpg lean.* overrides below.
PROMPT_KEY=deepseek-prover
# DATASET=~/projects/verl/data/mff-lwb-10k-seen.parquet
DATASET="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/data/processed_cot/mff-lwb-10k-seen.parquet"
TEST_DATASET="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/data/processed_cot/mff-lwb-unseen-200.parquet"

PROJECT_NAME='LeanProver_FCDPG'

# NCCL_P2P_DISABLE=0 keeps GPU peer-to-peer transfers enabled. Set it to 1
# (i.e. disable P2P) on GPUs without NVLink, such as A6000s or L40s.
export NCCL_P2P_DISABLE=0
#export VLLM_ATTENTION_BACKEND=XFORMERS
#export VERL_LOGGING_LEVEL=DEBUG
export VLLM_USE_V1=1
# 0 disables Ray's memory monitor, so Ray will not kill workers on high memory use.
export RAY_memory_monitor_refresh_ms=0

# Passed to the trainer as data.train_batch_size and the PPO mini-batch size.
BATCH_SIZE=128
# Passed to the trainer as actor_rollout_ref.rollout.n (samples per prompt).
N=8
EXPERIMENT_NAME="DeepSeek-Prover-1.5-7B-10k-dr_grpo-${BATCH_SIZE}-${N}"


# Launch verl PPO training with a Dr. GRPO configuration: GRPO advantages
# without std normalization plus seq-mean-token-sum-norm loss aggregation.
#
# Hydra overrides are kept in an array instead of a backslash-continued
# command: commenting out a line in the middle of a continued command ends the
# command there and silently drops every override after it, whereas array
# elements can be commented out individually.

# Shared by data.max_response_length and rollout.response_length so the two
# values cannot drift apart.
MAX_RESPONSE_LENGTH=512

TRAIN_ARGS=(
  trainer.project_name="${PROJECT_NAME}"
  trainer.experiment_name="${EXPERIMENT_NAME}"
  algorithm.adv_estimator=grpo
  algorithm.norm_adv_by_std_in_grpo=false
  # loss_agg_mode belongs to the actor config; a bare top-level loss_agg_mode
  # is not a key of the verl trainer config and is rejected by Hydra.
  actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-sum-norm
  data.train_files="${DATASET}"
  data.val_files="${TEST_DATASET}"
  actor_rollout_ref.actor.strategy="fsdp2"
  data.max_prompt_length=1024
  data.max_response_length="${MAX_RESPONSE_LENGTH}"
  data.filter_overlong_prompts=true
  data.prompt_key=prompt
  data.truncation='error'
  data.train_batch_size="${BATCH_SIZE}"
  +data.seed=42
  actor_rollout_ref.model.path="${MODEL}"
  actor_rollout_ref.actor.optim.lr=3e-6
  actor_rollout_ref.actor.use_dynamic_bsz=False
  actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.0
  actor_rollout_ref.actor.optim.weight_decay=0.0
  actor_rollout_ref.model.use_remove_padding=True
  actor_rollout_ref.actor.ppo_mini_batch_size="${BATCH_SIZE}"
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
  actor_rollout_ref.model.enable_gradient_checkpointing=True
  actor_rollout_ref.actor.fsdp_config.param_offload=False
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
  actor_rollout_ref.actor.grad_clip=1.0
  actor_rollout_ref.actor.clip_ratio=0.2
  actor_rollout_ref.actor.entropy_coeff=0.0
  # NOTE(review): the verl Dr. GRPO recipe turns the KL loss off
  # (use_kl_loss=False); confirm keeping it enabled here is intentional.
  actor_rollout_ref.actor.use_kl_loss=True
  actor_rollout_ref.actor.kl_loss_coef=0.01
  actor_rollout_ref.actor.kl_loss_type=low_var_kl
  actor_rollout_ref.actor.ppo_epochs=1
  actor_rollout_ref.actor.use_torch_compile=True
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
  actor_rollout_ref.rollout.tensor_model_parallel_size=2
  actor_rollout_ref.rollout.gpu_memory_utilization=0.5
  actor_rollout_ref.rollout.name=vllm
  actor_rollout_ref.rollout.n="${N}"
  actor_rollout_ref.rollout.disable_log_stats=False
  actor_rollout_ref.rollout.response_length="${MAX_RESPONSE_LENGTH}"
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8
  algorithm.gamma=1.0
  algorithm.lam=1.0
  actor_rollout_ref.ref.fsdp_config.param_offload=False
  actor_rollout_ref.model.trust_remote_code=True
  trainer.default_hdfs_dir=null
  trainer.nnodes=1
  trainer.n_gpus_per_node="${SLURM_GPUS_ON_NODE}"
  trainer.logger='["console","mlflow"]'
  trainer.save_freq=50
  trainer.max_actor_ckpt_to_keep=1
  trainer.val_before_train=False
  trainer.total_epochs=5
  critic.ppo_micro_batch_size_per_gpu=8
  critic.grad_clip=1.0
  custom_reward_function.path="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/lean/verifier.py"
  custom_reward_function.name=verify_with_deepseek_verifier
  # Reward manager: defines how rule-based rewards are computed and how
  # different reward sources are handled. The default is naive; prime runs
  # verification in parallel and is only appropriate when every verification
  # function is multiprocessing-safe.
  reward_model.reward_manager=prime
  reward_model.launch_reward_fn_async=True
  ray_init.num_cpus="${SLURM_CPUS_PER_TASK}"
  actor_rollout_ref.rollout.enforce_eager=False
  actor_rollout_ref.rollout.free_cache_engine=True
  actor_rollout_ref.model.use_liger=False
  actor_rollout_ref.model.use_fused_kernels=True
  actor_rollout_ref.model.fused_kernel_options.impl_backend=torch
  actor_rollout_ref.rollout.val_kwargs.top_k=50
  actor_rollout_ref.rollout.val_kwargs.top_p=1.0
  actor_rollout_ref.rollout.val_kwargs.temperature=0.8
  actor_rollout_ref.rollout.val_kwargs.do_sample=True
  actor_rollout_ref.rollout.val_kwargs.n=32
  trainer.test_freq=500
  #actor_rollout_ref.rollout.enable_chunked_prefill=True
  #actor_rollout_ref.model.custom_chat_template="{%- for message in messages -%}{%- if message['role'] == 'user' -%}{{ message['content'] }}{%- endif -%}{%- endfor -%}<｜begin of sentence｜>"
  # fcdpg: options specific to this fork (lean.* config group).
  # NOTE: MAX_WORKERS is not defined in this script; set it before enabling
  # lean.max_workers.
  #lean.prompt_key="${PROMPT_KEY}"
  #lean.num_samples=32
  #lean.problem_batch_size=16
  #lean.rejection_sampling=False
  #lean.advantage_threshold=True
  #lean.max_workers="${MAX_WORKERS}"
  #+trainer.sample_only=True
)

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo "${TRAIN_ARGS[@]}"
