#!/bin/bash
#SBATCH -J leanDPG
#SBATCH -p dpg
#SBATCH -A dpg
#SBATCH -N 1
##SBATCH -A dpg
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=20
#SBATCH --mem=200G
#SBATCH --time=15-00:00:00
#SBATCH --constraint="gpu_40g+"
#SBATCH --output=/beegfs/scratch/user/<anonymized>/experiments/lean/logs/%j.log
#SBATCH --error=/beegfs/scratch/user/<anonymized>/experiments/lean/logs/%j.err

# --- Environment bootstrap ---------------------------------------------------
# Source the miniconda activation script from the user's home directory.
source ~/miniconda3/bin/activate

# Write the job description and the GPU inventory/topology to the job log so
# that a run can be diagnosed after the fact.
scontrol show job "${SLURM_JOB_ID}"
nvidia-smi
nvidia-smi topo -m

# Load the user's shell configuration, then switch to the training environment.
# Stop immediately if the environment cannot be activated: continuing would
# start the trainer with whatever python3 happens to be first on PATH.
source ~/.bashrc
if ! conda activate verl; then
  echo "ERROR: could not activate conda environment 'verl'; aborting job." >&2
  exit 1
fi


# Interconnect settings for NCCL and UCX. They can make training faster, but
# whether they help depends on the infrastructure the job lands on.
export NCCL_IBEXT_DISABLE=1 NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5

# UCX device list: port 1 of mlx5_0 through mlx5_7, comma-separated.
# printf repeats its format once per brace-expanded index; the trailing comma
# it leaves behind is stripped before exporting.
UCX_NET_DEVICES=$(printf 'mlx5_%d:1,' {0..7})
export UCX_NET_DEVICES="${UCX_NET_DEVICES%,}"


# Experiment-tracking endpoint. The commented-out lines are earlier/alternative
# settings kept for reference; only the last MLFLOW_TRACKING_URI is active.
#export MLFLOW_TRACKING_URI="http://carvi.int.europe.naverlabs.com:3030/"
#export TENSORBOARD_DIR="/beegfs/scratch/user/<anonymized>/experiments/lean/logs/tensorboard"
#export MLFLOW_EXPERIMENT_ID="10"
#export MLFLOW_TRACKING_URI="http://frontend-nle002.kr.europe.naverlabs.com:8080/"
export MLFLOW_TRACKING_URI="http://proxy-carvi.int.europe.naverlabs.com:80/"

# Put the miniconda libraries first on the loader search path. The previous
# value is appended only when it is non-empty: a bare trailing ':' would add an
# empty entry, which the dynamic loader interprets as the current directory.
export LD_LIBRARY_PATH="${HOME}/miniconda3/lib/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# --- Model and data ------------------------------------------------------------
# Path values are quoted so that every character in them (including '<', '>'
# or whitespace in a directory name) is taken literally instead of being
# interpreted by the shell.
MODEL="/beegfs/scratch/user/<anonymized>/fcdpg-verl/DeepSeek-Prover-V1.5-SFT"
# In the visible part of this script PROMPT_KEY is only referenced by the
# commented-out lean.* overrides at the end of the file.
PROMPT_KEY="deepseek-prover"
# DATASET=~/projects/verl/data/mff-lwb-10k-seen.parquet
DATASET="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/results/DeepSeek-Prover-V1.5-SFT/mff-lwb-10k-seen-verified-scored.parquet"
#/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/data/processed_cot/mff-lwb-10k-seen.parquet
TEST_DATASET="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/data/processed_cot/mff-lwb-unseen-200.parquet"

PROJECT_NAME='LeanProver_FCDPG'
export HYDRA_FULL_ERROR=1

# NCCL peer-to-peer transfers are left enabled here (NCCL_P2P_DISABLE=0).
# Set this to 1 on A6000s or L40s that don't have NVLink.
export NCCL_P2P_DISABLE=0
#export VLLM_ATTENTION_BACKEND=XFORMERS
export VLLM_USE_V1=0
# NOTE(review): presumably turns off Ray's memory monitor -- confirm against the Ray docs.
export RAY_memory_monitor_refresh_ms=0

# --- Run hyper-parameters --------------------------------------------------------
# These are passed to the trainer as overrides by the launch command and are
# also baked into the experiment name.
BATCH_SIZE=128         # train batch size and PPO mini-batch size
N=4                    # value passed as actor_rollout_ref.rollout.n
DIVERGENCE=amari_alpha # value passed as algorithm.fcdpg.loss_divergence
ALPHA=0.01             # value passed as algorithm.fcdpg.alpha
EXPERIMENT_NAME="DeepSeek-Prover-1.5-7B-10k-${DIVERGENCE}-${ALPHA}-${BATCH_SIZE}-${N}-fixed"


# --- Launch training -------------------------------------------------------------
# Runs verl.trainer.main_ppo with the overrides collected in train_args.
#
# The overrides live in an array instead of one backslash-continued command so
# that (a) each value is passed as exactly one argument regardless of the
# characters it contains, (b) comments can sit next to the options they
# describe, and (c) no trailing line continuation can accidentally swallow
# whatever line follows the command.
#
# Reads: PROJECT_NAME, EXPERIMENT_NAME, DATASET, TEST_DATASET, MODEL,
#        BATCH_SIZE, N, ALPHA, DIVERGENCE (set earlier in this script) and
#        SLURM_GPUS_ON_NODE, SLURM_CPUS_PER_TASK (set by SLURM).
# The script's exit status is the exit status of the python3 process.
train_args=(
  # Experiment identity
  trainer.project_name="${PROJECT_NAME}"
  trainer.experiment_name="${EXPERIMENT_NAME}"

  # Advantage estimator
  algorithm.adv_estimator=fcdpg
  algorithm.fcdpg.z_default=0.0001

  # Data
  data.train_files="${DATASET}"
  data.val_files="${TEST_DATASET}"
  data.max_prompt_length=1024
  data.max_response_length=1024
  data.filter_overlong_prompts=true
  data.prompt_key=prompt
  data.train_batch_size="${BATCH_SIZE}"
  data.truncation='error'
  +data.seed=42

  # FCDPG settings
  algorithm.fcdpg.exponential_ebm=false
  algorithm.fcdpg.use_baseline=true
  algorithm.fcdpg.alpha="${ALPHA}"
  algorithm.fcdpg.reset_z=true
  algorithm.fcdpg.ir_max_clip=100
  algorithm.fcdpg.baseline_window_size=1024
  algorithm.fcdpg.loss_divergence="${DIVERGENCE}"

  # Actor / model / optimizer
  actor_rollout_ref.actor.policy_loss.loss_mode=vanilla
  actor_rollout_ref.model.path="${MODEL}"
  actor_rollout_ref.actor.optim.lr=3e-6
  actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05
  actor_rollout_ref.actor.optim.weight_decay=0.0
  actor_rollout_ref.model.use_remove_padding=True
  actor_rollout_ref.actor.ppo_mini_batch_size="${BATCH_SIZE}"
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
  actor_rollout_ref.model.enable_gradient_checkpointing=True
  actor_rollout_ref.actor.fsdp_config.param_offload=False
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
  actor_rollout_ref.actor.grad_clip=1.0
  actor_rollout_ref.actor.clip_ratio=0.2
  actor_rollout_ref.actor.use_kl_loss=False
  actor_rollout_ref.actor.kl_loss_coef=0.001
  actor_rollout_ref.actor.kl_loss_type=low_var_kl
  # entropy_coeff used to be passed twice (0.0, then 0); a single override is
  # kept, with the value that was given last.
  actor_rollout_ref.actor.entropy_coeff=0
  actor_rollout_ref.actor.ppo_epochs=1
  actor_rollout_ref.actor.use_torch_compile=false

  # Rollout and reference model
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
  actor_rollout_ref.rollout.tensor_model_parallel_size=2
  actor_rollout_ref.rollout.gpu_memory_utilization=0.5
  actor_rollout_ref.rollout.name=vllm
  actor_rollout_ref.rollout.n="${N}"
  actor_rollout_ref.rollout.response_length=1024
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8
  algorithm.gamma=1.0
  algorithm.lam=1.0
  actor_rollout_ref.ref.fsdp_config.param_offload=True
  actor_rollout_ref.model.trust_remote_code=True

  # Trainer
  trainer.default_hdfs_dir=null
  trainer.nnodes=1
  # Fallback mirrors "#SBATCH --gres=gpu:4" in the header; keep the two in sync.
  # It only applies when SLURM did not export the variable.
  trainer.n_gpus_per_node="${SLURM_GPUS_ON_NODE:-4}"
  trainer.logger='["console","mlflow"]'
  trainer.save_freq=50
  trainer.max_actor_ckpt_to_keep=1
  trainer.val_before_train=False
  trainer.total_epochs=5

  # Critic
  critic.ppo_micro_batch_size_per_gpu=16
  critic.grad_clip=1.0

  # Reward function. The path is quoted so it is passed through literally.
  custom_reward_function.path='/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/lean/verifier.py'
  custom_reward_function.name=verify_with_deepseek_verifier
  reward_model.reward_manager=prime
  reward_model.launch_reward_fn_async=True

  # Fallback mirrors "#SBATCH --cpus-per-task=20" in the header; keep in sync.
  ray_init.num_cpus="${SLURM_CPUS_PER_TASK:-20}"

  # Rollout engine / kernels
  actor_rollout_ref.rollout.enforce_eager=False
  actor_rollout_ref.rollout.free_cache_engine=True
  actor_rollout_ref.model.use_liger=False
  actor_rollout_ref.model.use_fused_kernels=True
  actor_rollout_ref.model.fused_kernel_options.impl_backend=torch

  # Validation sampling
  actor_rollout_ref.rollout.val_kwargs.top_k=50
  actor_rollout_ref.rollout.val_kwargs.top_p=1.0
  actor_rollout_ref.rollout.val_kwargs.temperature=0.8
  actor_rollout_ref.rollout.val_kwargs.do_sample=True
  actor_rollout_ref.rollout.val_kwargs.n=16
  trainer.test_freq=500
)

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo "${train_args[@]}"


  #actor_rollout_ref.model.custom_chat_template='"<｜begin▁of▁sentence｜>{{ messages[0]['content'] }}<｜begin▁of▁sentence｜>"' \

  

  # Reward manager (reward_model.reward_manager): defines how rule-based rewards
  # are computed and how different reward sources are handled. The default is
  # `naive`; if all verification functions are multiprocessing-safe, it can be
  # set to `prime` for parallel verification.
# fcdpg

#  lean.prompt_key=${PROMPT_KEY} \
#  lean.num_samples=32 \
#  lean.problem_batch_size=16 \
#  lean.rejection_sampling=False \
#  lean.advantage_threshold=True \
#  lean.max_workers=${MAX_WORKERS} \

#  +trainer.sample_only=True
