#!/bin/bash
# Echo every command before it runs (debug tracing for cluster launches).
set -x

# NCCL / InfiniBand communication settings.
export \
    NCCL_IB_TC=136 \
    NCCL_IB_SL=5 \
    NCCL_IB_GID_INDEX=3 \
    NCCL_SOCKET_IFNAME=eth0 \
    NCCL_DEBUG=INFO \
    NCCL_IB_HCA=mlx5 \
    NCCL_IB_TIMEOUT=22 \
    NCCL_IB_QPS_PER_CONNECTION=8 \
    NCCL_NET_PLUGIN=none

# CUDA architecture list used when PyTorch builds CUDA extensions.
export TORCH_CUDA_ARCH_LIST="9.0"

# Ray runtime environment overrides
# (Ray defaults noted by the original author: raylet start wait 10 s,
#  worker registration timeout 30 s).
export \
    RAY_raylet_start_wait_time_s=6000 \
    RAY_worker_register_timeout_seconds=60

# Ask Hydra to print the complete stack trace when the job errors out.
export HYDRA_FULL_ERROR=1

# Uniform-eval settings: preprocessing mode, scoring concurrency, and the
# OpenAI-style LLM endpoint used during evaluation.
export \
    UNIFORM_EVAL_NO_BAN_MODULES=1 \
    UNIFORM_EVAL_PREPROCESS=multi_agent \
    COMPUTE_SCORE_CONCURRENT_NUMBER=512 \
    EVAL_LLM_URL=http://172.22.9.65:10080/v1 \
    EVAL_LLM_KEY=EMPTY \
    EVAL_LLM_NAME=qwen2.5-72b-instruct


# qwen-agent feature toggles and tool-call worker pool size.
export \
    QWEN_SEARCH_ENABLE_CSI=false \
    QWEN_IDP_ENABLE_CSI=false \
    SPECIAL_CODE_MODE=false \
    QWEN_DOC_PARSER_USE_IDP=false \
    MAX_TOOL_CALL_WORKERS=20


# Tool-service endpoints and credentials.
# The script enables `set -x` near the top, which would echo every key
# exported below to stderr. Tracing is paused while these are exported and
# restored to its previous state afterwards.
case $- in
    *x*) _xtrace_was_on=1 ;;
    *) _xtrace_was_on=0 ;;
esac
{ set +x; } 2>/dev/null

## cache
export QWEN_AGENT_DEFAULT_WORKSPACE=path/nas_cache/workspace
## google engine
export QWEN_SEARCH_URL=URL
export QWEN_SEARCH_KEY=KEY
export QWEN_SEARCH_SCENE=SCENE
export GOOGLE_TOPK=10
export GOOGLE_USERNAME=USERNAME
export GOOGLE_MAXPAGELENGTH=4000
export GOOGLE_CONCURRENCY=10
## google scholar
export SCHOLAR_SEARCH_URL=URL
export SCHOLAR_SEARCH_KEY=KEY
export SCHOLAR_SEARCH_CONCURRENCY=10
export SCHOLAR_NUM=5
export ONLY_SCHOLAR=False
export READPAGE=True
## JINA
export JINA_URL=https://r.jina.ai/
export JINA_KEY=KEY
export JINA_CONCURRENCY=20

export WANDB_API_KEY=KEY

# Re-enable tracing only if it was on before this section.
if [ "${_xtrace_was_on}" -eq 1 ]; then set -x; fi

timestamp=$(date "+%Y%m%d_%H%M%S")

# Launch arguments (DSW). Each value comes from the matching positional
# argument ($1..$4) when non-empty, otherwise from the first non-empty
# scheduler variable in its fallback chain, otherwise from the literal default.
GPUS_PER_NODE=${1:-${MLP_WORKER_GPU:-${KUBERNETES_CONTAINER_RESOURCE_GPU:-8}}}
NNODES=${2:-${MLP_WORKER_NUM:-${WORLD_SIZE:-1}}}
NODE_RANK=${3:-${MLP_WORKER_RACK_RANK_INDEX:-${MLP_ROLE_INDEX:-${RANK:-0}}}}
MASTER_ADDR=${4:-${MLP_WORKER_0_HOST:-${MASTER_ADDR:-127.0.0.1}}}
export GPUS_PER_NODE NNODES NODE_RANK MASTER_ADDR
# export MASTER_PORT=${5:-${MLP_WORKER_0_PORT:-${MASTER_PORT:-1234}}}
printf '%s | %s | %s | %s\n' "${GPUS_PER_NODE}" "${NNODES}" "${NODE_RANK}" "${MASTER_ADDR}"

# Total GPUs across the whole job.
GPU_NUM=$(( GPUS_PER_NODE * NNODES ))

# Model weights; the tokenizer is loaded from the same directory and the
# directory's basename is used in the experiment name.
BASE_MODEL=PATH/Qwen2.5-7B-Instruct
TOKENIZER_PATH="${BASE_MODEL}"
ckpt_name="$(basename "${BASE_MODEL}")"

# Training / validation data. VAL_DATA is a list literal handed to the
# trainer's data.val_files override.
TRAIN_DATA=./data/rl_train_data/rl_train.jsonl
VAL_DATA="['./data/test_data/hle.jsonl','./data/test_data/mhqa.jsonl','./data/test_data/shqa.jsonl']"
dataset_name=rl_train

# Batch sizes and parallelism degrees.
BATCH_SIZE=32       # data.train_batch_size
MINI_BATCH_SIZE=32  # actor ppo_mini_batch_size
N=16                # rollout.n (samples per prompt)
SP_SIZE=2           # Ulysses sequence-parallel size (actor and ref)
TP_SIZE=1           # vLLM rollout tensor-parallel size

# Torch distributed / CUDA runtime settings.
# NCCL_DEBUG is deliberately not exported again here. The NCCL section at the
# top of the script already sets it to INFO, and a second copy would silently
# override any change made there.
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export CUDA_DEVICE_MAX_CONNECTIONS=1
# export VLLM_ATTENTION_BACKEND=XFORMERS
# Export the rollout tensor-parallel size so child processes can read it.
export TP_SIZE=${TP_SIZE}

# Agent / tool-use settings read by the rollout code.
export TOOL_VERSION=en_v3_code
export TOOL_RESPONSE_TAG=None
export TOOL_CALL_TAG=tool_call  # tag used when parsing system2 results
export VERBOSE=False
export ANSWER_TAG=answer
export TOOL_ERROR_TOLERANCE=2
export TOLERANCE_TOOLS="PythonInterpreter"
# export PYTHON_CACHE_UUID=$(cat /proc/sys/kernel/random/uuid | tr -d '\n')
export SANDBOX_FUSION_ENDPOINT="http://172.22.4.103:8080"
export CHECK_BOX=False

# Algorithm and sampling hyperparameters.
lr=1e-6
rl_method=grpo
system1_mode=training
beta=2
temperature=1.0
top_p=0.95
adv=True
tool_response_role=tool

# Sequence-length budgets, in tokens.
system1_max_prompt_length=$(( 23 * 1024 ))
system1_max_token=$(( 8 * 1024 ))        # system1 max_response_length
MAX_PROMPT_LENGTH=$(( 3 * 1024 ))
MAX_RESPONSE_LENGTH=$(( 28 * 1024 ))
MAX_NUM_BATCHED_TOKENS=$(( 32 * 1024 ))  # vllm max_num_batched_tokens

# Per-GPU token budget for dynamic batching: the longer prompt limit plus the
# longer response limit (system1 vs. main), divided across SP_SIZE ranks.
longest_prompt=$(( system1_max_prompt_length > MAX_PROMPT_LENGTH ? system1_max_prompt_length : MAX_PROMPT_LENGTH ))
longest_response=$(( system1_max_token > MAX_RESPONSE_LENGTH ? system1_max_token : MAX_RESPONSE_LENGTH ))
PPO_MAX_TOKEN_LEN_PER_GPU=$(( (longest_prompt + longest_response) / SP_SIZE ))

# Run naming and output location.
EXPERIMENT_NAME="${dataset_name}_${ckpt_name}_${system1_mode}_G${GPU_NUM}_bsz${BATCH_SIZE}_${MINI_BATCH_SIZE}_beta${beta}_lr${lr}_n${N}_R${READPAGE}_${tool_response_role}_T${temperature}_continue"
PROJECT_NAME=search_local
SAVE_DIR="path_to/nas_output_dir/debug/rl/${rl_method}/${system1_mode}/${PROJECT_NAME}/${TOOL_VERSION}/${EXPERIMENT_NAME}"

# Executables from the `mars` conda environment.
conda_env_bin=/root/miniconda3/envs/mars/bin
python="${conda_env_bin}/python"
ray="${conda_env_bin}/ray"
pip="${conda_env_bin}/pip"

# Install the sandbox_fusion client from the Tsinghua PyPI mirror.
"${pip}" install sandbox_fusion -i https://pypi.tuna.tsinghua.edu.cn/simple

# Stop any Ray processes already running on this node before (re)starting.
"${ray}" stop

#######################################
# Poll the local Ray cluster until it answers and reports at least the
# requested number of alive nodes, or until the timeout elapses.
# Globals:   python (read) - interpreter used to query Ray
# Arguments: $1 - minimum number of alive nodes to wait for
#            $2 - timeout in seconds
# Returns:   0 once the cluster is ready, 1 on timeout
#######################################
wait_for_ray_cluster() {
    local want_nodes=$1
    local timeout_s=$2
    local deadline=$(( SECONDS + timeout_s ))
    local alive=""
    while (( SECONDS < deadline )); do
        # Attaching with address="auto" fails while the head is still
        # starting; once attached, count the nodes Ray reports as alive.
        alive=$("${python}" -c 'import ray; ray.init(address="auto"); print(sum(1 for n in ray.nodes() if n["Alive"]))' 2>/dev/null | tail -n 1)
        if [[ "${alive}" =~ ^[0-9]+$ ]] && (( alive >= want_nodes )); then
            echo "Ray cluster ready: ${alive}/${want_nodes} node(s) alive"
            return 0
        fi
        sleep 10
    done
    echo "WARNING: Ray cluster not ready after ${timeout_s}s (alive: ${alive:-0}/${want_nodes}); launching anyway" >&2
    return 1
}

if [ "${NODE_RANK}" -eq 0 ]; then

    # Ray head runs in the background; the trainer below attaches to it.
    "${ray}" start --block  --head --port=6379 &
    export RAY_DEBUG=legacy

    mkdir -p "${SAVE_DIR}"

    # The head was just backgrounded and workers only join after their
    # 60 s sleep. Without this wait the trainer can call ray.init() before
    # the head is reachable or before all NNODES nodes are registered.
    # Best-effort: on timeout we warn and launch anyway (previous behaviour).
    wait_for_ray_cluster "${NNODES}" "${RAY_CLUSTER_WAIT_TIMEOUT_S:-600}" || true

    # Hydra overrides. Values holding paths or list literals are quoted so the
    # shell does not word-split or glob-expand them.
    "${python}" -m verl.trainer.main_ppo \
        "data.train_files=${TRAIN_DATA}" \
        "data.val_files=${VAL_DATA}" \
        data.train_batch_size=${BATCH_SIZE} \
        data.max_prompt_length=${MAX_PROMPT_LENGTH} \
        data.max_response_length=${MAX_RESPONSE_LENGTH} \
        "actor_rollout_ref.model.path=${BASE_MODEL}" \
        "actor_rollout_ref.model.tokenizer=${TOKENIZER_PATH}" \
        actor_rollout_ref.model.use_remove_padding=True \
        actor_rollout_ref.model.enable_gradient_checkpointing=True \
        actor_rollout_ref.actor.ppo_mini_batch_size=${MINI_BATCH_SIZE} \
        actor_rollout_ref.actor.use_dynamic_bsz=True \
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU} \
        actor_rollout_ref.actor.use_kl_loss=True \
        actor_rollout_ref.actor.kl_loss_coef=0.0 \
        actor_rollout_ref.actor.kl_loss_type=low_var_kl \
        actor_rollout_ref.actor.entropy_coeff=0.0 \
        actor_rollout_ref.actor.ulysses_sequence_parallel_size=${SP_SIZE} \
        actor_rollout_ref.actor.optim.lr=${lr} \
        actor_rollout_ref.actor.fsdp_config.param_offload=False \
        actor_rollout_ref.actor.fsdp_config.grad_offload=False \
        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
        actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
        actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((PPO_MAX_TOKEN_LEN_PER_GPU)) \
        actor_rollout_ref.ref.fsdp_config.param_offload=False \
        actor_rollout_ref.ref.ulysses_sequence_parallel_size=${SP_SIZE} \
        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
        actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((PPO_MAX_TOKEN_LEN_PER_GPU)) \
        actor_rollout_ref.rollout.name=vllm \
        actor_rollout_ref.rollout.enforce_eager=False \
        actor_rollout_ref.rollout.free_cache_engine=False \
        actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
        actor_rollout_ref.rollout.tensor_model_parallel_size=${TP_SIZE} \
        actor_rollout_ref.rollout.max_num_batched_tokens=${MAX_NUM_BATCHED_TOKENS} \
        actor_rollout_ref.rollout.n=${N} \
        actor_rollout_ref.rollout.temperature=${temperature} \
        actor_rollout_ref.rollout.top_p=${top_p} \
        actor_rollout_ref.rollout.monitor_key=null \
        +actor_rollout_ref.rollout.use_rl_agent=True \
        +actor_rollout_ref.rollout.function_list=['GoogleSearch','GoogleScholar','PythonInterpreter'] \
        actor_rollout_ref.rollout.multi_agent_pattern.tool_external_concurrency=30 \
        actor_rollout_ref.rollout.multi_agent_pattern.beta=${beta} \
        actor_rollout_ref.rollout.multi_agent_pattern.max_depth=10 \
        actor_rollout_ref.rollout.multi_agent_pattern.tool_response_role=${tool_response_role} \
        actor_rollout_ref.rollout.multi_agent_pattern.adv=${adv} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.system1_mode=${system1_mode} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.temperature=${temperature} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.top_p=${top_p} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.max_prompt_length=${system1_max_prompt_length} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.max_tokens=${system1_max_token} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.readpage=${READPAGE} \
        actor_rollout_ref.rollout.multi_agent_pattern.system1_sampling_params.enable_thinking=False \
        'actor_rollout_ref.rollout.multi_agent_pattern.system2_sampling_params.add_prefix="<think>"' \
        actor_rollout_ref.rollout.multi_agent_pattern.system2_sampling_params.stop='["</tool_call>","</answer>"]' \
        actor_rollout_ref.rollout.multi_agent_pattern.system2_sampling_params.enable_thinking=False \
        reward_model.reward_manager=unieval \
        algorithm.adv_estimator=${rl_method} \
        algorithm.kl_ctrl.kl_coef=0.000 \
        trainer.logger=['console','wandb'] \
        +trainer.val_before_train=True \
        trainer.n_gpus_per_node=${GPUS_PER_NODE} \
        trainer.nnodes=${NNODES} \
        trainer.save_freq=2 \
        trainer.test_freq=2 \
        trainer.project_name=${PROJECT_NAME} \
        trainer.experiment_name=${EXPERIMENT_NAME} \
        "trainer.default_local_dir=${SAVE_DIR}" \
        "hydra.run.dir=${SAVE_DIR}" \
        trainer.total_epochs=50 2>&1 | tee "${SAVE_DIR}/log_${timestamp}.log"
    # A pipeline's status is tee's (normally 0); surface the trainer's own
    # exit code so schedulers see training failures.
    train_status=${PIPESTATUS[0]}
    exit "${train_status}"

else
    # Delay before joining so the head node (NODE_RANK 0) has time to start.
    sleep 60
    
    "${ray}" start --block  --address="${MASTER_ADDR}:6379"

fi
