#!/bin/bash
# installation: follow the latest [verl](https://github.com/volcengine/verl) installation.

# ---------------------------------------------------------------------------
# Run configuration.
#
# Most knobs follow the pattern  name=${ENV_OVERRIDE:-default}  so they can be
# changed per launch without editing this file, e.g.
#   lr=5e-7 nsample=16 bash <this script>
# ---------------------------------------------------------------------------

set -e

# Credentials are assigned BEFORE command tracing is switched on so that
# `set -x` never echoes them into the job log. A value already present in the
# environment is kept; otherwise the variable is exported empty (fill it in
# here or export it before launching).
export WANDB_API_KEY=${WANDB_API_KEY:-}
export SWANLAB_API_KEY=${SWANLAB_API_KEY:-}

set -x

isbaidu=True
# Root directory for outputs, the model and the ip_tmp rendezvous folder.
# Empty by default, which makes every derived path start at "/".
workdir=${workdir:-}

rewardtype=${rewardtype:-"dapo"}
nothink=${nothink:-"False"}
lr=${lr:-"1e-6"}
useref=${useref:-"False"}
project_name='Qwen2.5-Reproduce'
loss_mode=${lossmode:-"kl_cov"}
entropy=${entropy:-"entropy_regularize"}
basemodelname=${basemodelname:-"math"}
train_prompt_bsz=${rbsz:-"256"} # number of queries
train_prompt_mini_bsz=${tbsz:-"32"} # number of queries
ppo_micro=${ppo_micro:-"8"}
n_resp_per_prompt=${nsample:-"8"}
explore=${explore:-"False"}
explore_loss=${explore_loss:-"lower_entropy"}
entropy_coeff=${coeff:-"0.0"}
ase=${ase:-"False"}
exp_name=debug
# The tag always follows the experiment name. (An earlier
# `tagname=${tagname:-"none"}` default was unconditionally overwritten by this
# assignment, so it has been dropped; the resulting value is unchanged.)
tagname=$exp_name
adv_estimator=${adv:-"grpo"}
zero=${zero:-"True"}
doc=${doc:-"False"}

# KL settings: both the in-reward KL term and the KL loss are switched off.
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=${cliphigh:-"0.28"}

max_prompt_length=10000
max_response_length=${maxlen:-"10000"} # $((512 * 12))
enable_overlong_buffer=False
overlong_buffer_len=512
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=${filter:-"False"}
filter_groups_metric=acc
max_num_gen_batches=3

# Batch sizes derived from the training prompt batch size.
gen_prompt_bsz=$((train_prompt_bsz * 1))
val_batch_size=$((train_prompt_bsz * 2))

# Resuming is enabled only when `resume` is set in the environment (even to an
# empty string, since `-v` tests for "set", not "non-empty").
if [ -v resume ]; then
  resume_mode=resume_path
else
  resume=none
  resume_mode=disable
fi

max_token=${maxtoken:-"32768"}
min_pixels=3136
max_pixels=401408 # 512*28*28
sample_rate=0.3

# Dataset locations: honour values supplied through the environment and fall
# back to the placeholder paths otherwise. TEST_FILE's default is a
# single-element list literal.
export TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"}
export TEST_FILE=${TEST_FILE:-"[/YOUR_TRAIN_FILE_PATH]"}
export CKPTS_DIR=${workdir}/outputs/${exp_name}
MODEL_PATH=${MODEL_PATH:-"${workdir}/mimo_vl_sft_2508"}
export RAY_DATA_HOME=${workdir}/verl
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}

# Algorithm
temperature=${temp:-"1.0"}
top_p=${top_p:-"1.0"}
top_k=${top_k:-"-1"} # 0 for HF rollout, -1 for vLLM rollout
ppo_kl_coef=1
k_percent=0.2

# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False

#######################################
# Print the name of the first network interface that is UP, is not the
# loopback device, and carries an IPv4 address that is either in a private
# range (10/8, 172.16/12, 192.168/16) or has scope "global". Addresses in
# 127.0.0.0/8 and 169.254.0.0/16 are ignored.
#
# The `ip addr show` output is scanned in a single pass: every header line
# ("N: name: <FLAGS>") starts a new interface block, and the "inet" lines that
# follow belong to that block until the next header. This keeps an address on
# the very last scanned line usable and never feeds the interface name into
# another regex.
#
# Arguments:
#   $1 - optional; how many leading lines of `ip addr show` to scan
#        (default 10, the historical limit).
# Outputs:
#   The interface name on stdout, or an empty line when nothing qualifies
#   (so `export GLOO_SOCKET_IFNAME=$(find_interface)` yields an empty value).
# Returns:
#   0 in both cases.
#######################################
find_interface() {
  local max_lines=${1:-10}

  # Declared separately from the assignment so `local` does not mask the
  # pipeline's exit status.
  local ip_output
  ip_output=$(ip addr show | head -n "$max_lines")

  # Any line beginning "N: " opens a new interface block.
  local block_start_re='^[0-9]+: '
  # Header of an interface whose flag list contains "UP". iproute2 prints
  # linked devices as "name@peer"; only the part before '@' is the real
  # interface name, so the suffix is matched but not captured in group 1.
  local header_re='^[0-9]+: ([^:@]+)(@[^:]*)?: <.*UP.*>'
  # IPv4 address line: group 1 = address, group 3 = scope.
  local inet_re='inet ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)/([0-9]+) .*scope ([^ ]+)'
  # RFC 1918 private ranges.
  local private_re='^(10\.[0-9]{1,3}|172\.(1[6-9]|2[0-9]|3[0-1])|192\.168)\.[0-9]{1,3}\.[0-9]{1,3}$'

  local line
  local current_iface=""      # candidate interface of the block being read
  local selected_interface=""
  local ip_address scope

  while IFS= read -r line; do
    if [[ "$line" =~ $block_start_re ]]; then
      # New block: forget the previous candidate. Interfaces that are not UP,
      # and the loopback device, leave the candidate empty so that their
      # address lines are skipped below.
      current_iface=""
      if [[ "$line" =~ $header_re ]]; then
        if [[ "${BASH_REMATCH[1]}" != "lo" ]]; then
          current_iface=${BASH_REMATCH[1]}
        fi
      fi
      continue
    fi

    [[ -n "$current_iface" ]] || continue
    [[ "$line" =~ $inet_re ]] || continue

    # Copy the captures out before the next =~ overwrites BASH_REMATCH.
    ip_address=${BASH_REMATCH[1]}
    scope=${BASH_REMATCH[3]}

    # Skip loopback (127.0.0.0/8) and link-local (169.254.0.0/16) addresses.
    if [[ "$ip_address" == 127.* || "$ip_address" == 169.254.* ]]; then
      continue
    fi

    if [[ "$ip_address" =~ $private_re || "$scope" == "global" ]]; then
      selected_interface=$current_iface
      break
    fi
  done <<< "$ip_output"

  # Always emit exactly one line; it is empty when no interface was found.
  printf '%s\n' "$selected_interface"
}

# ---------------------------------------------------------------------------
# Cluster bring-up and training launch.
#
# Reads RANK and WORLD_SIZE from the environment.
#   RANK == 0 : starts the Ray head, publishes its address to
#               ${workdir}/ip_tmp/ip_${tagname}.txt, waits for NNODES nodes,
#               then launches recipe.dapo.main_dapo.
#   otherwise : installs the same Python packages as the head, waits for the
#               address file and joins the cluster with `ray start --block`.
# The address file is the only rendezvous mechanism, so ${workdir}/ip_tmp
# presumably has to live on storage visible to every node -- TODO confirm.
# ---------------------------------------------------------------------------

# Install the Python packages the job needs at runtime (Tsinghua PyPI mirror).
# Called on every node so that head and workers end up with the same set.
install_runtime_deps() {
  local -r index_url="https://pypi.tuna.tsinghua.edu.cn/simple"
  local pkg
  for pkg in nltk jieba latex2sympy2_extended math_verify pyahocorasick; do
    pip install "$pkg" -i "$index_url"
  done
  pip install -U datasets -i "$index_url"
  # pip install -U transformers -i "$index_url"
}

# NOTE(review): `myvar` is not defined anywhere in this script, so this
# expands to an empty value unless the caller provides the array -- confirm
# whether RAY_MASTER_NODE_ADDRESS is still needed.
export RAY_MASTER_NODE_ADDRESS=${myvar[(($WORLD_SIZE-1))]}
export RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-40000)

export NCCL_NET_PLUGIN=none
export NCCL_IB_TIMEOUT=22
export NCCL_IB_RETRY_CNT=15
export NCCL_DEBUG=INFO


export HOST_IP=0.0.0.0
export VLLM_HOST_IP=0.0.0.0
export HYDRA_FULL_ERROR=1
export RAY_IGNORE_UNHANDLED_ERRORS=1
cd "${workdir}/"

WORKING_DIR=${WORKING_DIR:-"${PWD}"}
export HF_ENDPOINT=https://hf-mirror.com
# W&B credentials must come from the WANDB_API_KEY environment variable; never
# commit a key to this file, not even inside a comment.
export NNODES=$WORLD_SIZE
nnode=$NNODES
echo "rank $RANK"

# File through which the head publishes its address to the workers.
ip_file="${workdir}/ip_tmp/ip_${tagname}.txt"

if [ "$RANK" = "0" ]; then
    # Drop an address file left over from an earlier run with the same tag so
    # workers are less likely to pick up a stale head address.
    mkdir -p "${workdir}/ip_tmp"
    rm -f -- "$ip_file"

    # Start Ray head node and capture the output
    ray_output=$(ray start --head --num-gpus 8)

    # Pull the head address out of the "ray start --address='...'" hint that
    # `ray start --head` prints. Under `set -e` the script stops here if the
    # hint is missing.
    ip_address=$(grep -oP "ray start --address='\K[^']+" <<< "$ray_output")

    # Publish the address for the worker nodes.
    echo "$ip_address" > "$ip_file"
    cat "$ip_file"

    export RAY_ADDRESS=$ip_address #"http://localhost:8265"

    # Give the workers a head start, then poll once per second until at least
    # NNODES nodes show up in the "Active:" section of `ray status`.
    sleep 15
    MAX_RETRIES=600
    COUNT=0
    while [[ $COUNT -lt $MAX_RETRIES ]]; do
        ACTIVE_NODES=$(ray status | sed -n '/Active:/,/Pending:/p' | grep "node_" | wc -l)
        if [ "$ACTIVE_NODES" -ge "$NNODES" ]; then
            break
        fi
        echo "Master Waiting for Ray cluster to be ready... Attempt $((COUNT+1))"
        sleep 1
        COUNT=$((COUNT+1))
    done
    if [ "$COUNT" -ge "$MAX_RETRIES" ]; then
        # Historical behaviour is to carry on regardless; keep that, but make
        # the situation visible in the log.
        echo "WARNING: only ${ACTIVE_NODES} of ${NNODES} Ray nodes are active after ${MAX_RETRIES} attempts; continuing anyway" >&2
    fi
    ray status

    # Optional warm start: when hf_resume is set, the model is loaded from
    # ${hf_resume}/hfmodel, which is produced by convert.sh if it does not
    # exist yet.
    # NOTE(review): resume_mode was decided earlier from `resume` alone and is
    # not changed here -- confirm that is intended.
    if [ -v hf_resume ]; then
        export MODEL_PATH="${hf_resume}/hfmodel"
        export resume="${hf_resume}"
        if [ ! -e "$MODEL_PATH" ]; then
            echo "start model conversion"
            bash "${workdir}/convert/convert.sh"
        fi
    fi

    install_runtime_deps

    # Hydra overrides for recipe.dapo.main_dapo, kept in their original order.
    # Keys prefixed with '+' are additions to the base config.
    train_args=(
        # --- data ---
        "data.train_files=${TRAIN_FILE}"
        "data.val_files=${TEST_FILE}"
        data.prompt_key=prompt
        data.truncation=left
        data.filter_overlong_prompts=False
        "data.max_prompt_length=${max_prompt_length}"
        "data.max_response_length=${max_response_length}"
        "data.gen_batch_size=${gen_prompt_bsz}"
        "data.train_batch_size=${train_prompt_bsz}"
        "data.val_batch_size=${val_batch_size}"
        data.return_raw_chat=True
        data.shuffle=False
        "data.zero=${zero}"
        "+data.doc=${doc}"
        "+data.explore=${explore}"
        "+data.min_pixels=${min_pixels}"
        "+data.max_pixels=${max_pixels}"
        # NOTE(review): fixed at 1.0; the script-level sample_rate variable is
        # not used here.
        +data.sample_rate=1.0
        # --- rollout / actor loss ---
        "actor_rollout_ref.rollout.n=${n_resp_per_prompt}"
        actor_rollout_ref.rollout.max_model_len=32768
        "actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}"
        "actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}"
        "actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}"
        "actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}"
        actor_rollout_ref.actor.clip_ratio_c=10.0
        "actor_rollout_ref.actor.loss_mode=${loss_mode}"
        "actor_rollout_ref.actor.k_percent=${k_percent}"
        "actor_rollout_ref.actor.ppo_kl_coef=${ppo_kl_coef}"
        "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro}"
        actor_rollout_ref.rollout.mode=sync
        # --- algorithm ---
        +algorithm.bonus_ratio=0.2
        # `sse` has no default elsewhere in this script; fall back to False
        # (like `ase`) instead of passing an empty override.
        "+algorithm.sse=${sse:-False}"
        "+algorithm.ase=${ase}"
        "algorithm.adv_estimator=${adv_estimator}"
        "algorithm.use_kl_in_reward=${use_kl_in_reward}"
        "algorithm.kl_ctrl.kl_coef=${kl_coef}"
        "algorithm.filter_groups.enable=${enable_filter_groups}"
        "algorithm.filter_groups.metric=${filter_groups_metric}"
        "algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches}"
        # --- model / optimisation ---
        actor_rollout_ref.model.use_remove_padding=True
        "actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}"
        +actor_rollout_ref.actor.use_ref=False
        "actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
        "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
        "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token}"
        "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token}"
        "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token}"
        "actor_rollout_ref.model.path=${MODEL_PATH}"
        "actor_rollout_ref.model.enable_gradient_checkpointing=${gradckpt:-"True"}"
        "actor_rollout_ref.actor.optim.lr=${lr}"
        actor_rollout_ref.actor.optim.weight_decay=0
        actor_rollout_ref.actor.optim.warmup_style=constant
        "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
        "actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size}"
        "actor_rollout_ref.actor.fsdp_config.param_offload=${offload}"
        "actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload}"
        "actor_rollout_ref.actor.entropy_coeff=${entropy_coeff}"
        +actor_rollout_ref.actor.explore_loss=none
        actor_rollout_ref.actor.grad_clip=1.0
        "actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}"
        actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
        # --- rollout engine / sampling ---
        "actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization:-"0.85"}"
        "actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size}"
        actor_rollout_ref.rollout.tensor_model_parallel_size=1
        actor_rollout_ref.rollout.enable_chunked_prefill=True
        "actor_rollout_ref.rollout.max_num_batched_tokens=${max_token}"
        "actor_rollout_ref.rollout.temperature=${temperature}"
        "actor_rollout_ref.rollout.top_p=${top_p}"
        "actor_rollout_ref.rollout.top_k=${top_k}"
        actor_rollout_ref.rollout.val_kwargs.temperature=0.6
        "actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}"
        "actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}"
        actor_rollout_ref.rollout.val_kwargs.do_sample=True
        actor_rollout_ref.rollout.val_kwargs.n=1
        "actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size}"
        "actor_rollout_ref.ref.fsdp_config.param_offload=${offload}"
        actor_rollout_ref.ref.ulysses_sequence_parallel_size=1
        actor_rollout_ref.actor.fsdp_config.fsdp_size=-1
        # --- reward ---
        "reward_model.reward_manager=${rewardtype}"
        "reward_model.overlong_buffer.enable=${enable_overlong_buffer}"
        "reward_model.overlong_buffer.len=${overlong_buffer_len}"
        "reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor}"
        # --- trainer ---
        # Quoted as one word so the brackets are never treated as a glob; the
        # value passed on is the same [console,wandb] the unquoted form gave.
        'trainer.logger=[console,wandb]'
        "trainer.project_name=${project_name}"
        "trainer.experiment_name=${exp_name}"
        trainer.n_gpus_per_node=8
        "trainer.nnodes=${NNODES}"
        trainer.val_before_train=True
        trainer.max_actor_ckpt_to_keep=3
        trainer.test_freq=10
        trainer.save_freq=40
        trainer.total_epochs=2
        "trainer.default_local_dir=${CKPTS_DIR}"
        "+trainer.entropy_minimization=${entropy}"
        # `valonly` has no default elsewhere in this script; fall back to
        # False instead of passing an empty override.
        "+trainer.val_only=${valonly:-False}"
        "trainer.resume_mode=${resume_mode}"
        "trainer.resume_from_path=${resume}"
        "trainer.validation_data_dir=${CKPTS_DIR}/val"
    )

    python -m recipe.dapo.main_dapo "${train_args[@]}"

else
    set -x
    echo "for baidu cluster"

    # Install packages BEFORE joining the cluster: `ray start --block` keeps
    # this script in the foreground for as long as the node is part of the
    # cluster, so anything placed after it would only run once Ray has exited.
    install_runtime_deps

    # Wait for the head to publish its address instead of relying on a fixed
    # delay. A non-empty file left behind by an earlier run with the same tag
    # cannot be told apart from a fresh one here.
    sleep 5
    HEAD_IP_WAIT_SECS=600
    waited=0
    until [[ -s "$ip_file" ]]; do
        if (( waited >= HEAD_IP_WAIT_SECS )); then
            echo "ERROR: head address file ${ip_file} did not appear within ${HEAD_IP_WAIT_SECS}s" >&2
            exit 1
        fi
        echo "Worker waiting for head address file... Attempt $((waited+1))"
        sleep 1
        waited=$((waited+1))
    done

    # Read the IP address from the file and assign it to the variable "head_ip"
    head_ip=$(cat "$ip_file")
    # gloo=$(cat ${workdir}/ip_tmp/gloo_${tagname}.txt)
    # export GLOO_SOCKET_IFNAME=$gloo
    # echo "gloo: $GLOO_SOCKET_IFNAME"
    # Print the value of head_ip for verification
    echo "Head IP Address: $head_ip"
    ray start --address "${head_ip}" --num-gpus 8 --block

    # Only reached after the blocking `ray start` above has returned.
    echo "$HOST_IP"

fi
