#!/bin/bash

# Strict mode plus command tracing for the whole launch script.
set -xeuo pipefail
# Raise the open-file limit (Ray/NCCL open many sockets and files). When 65535
# is not permitted (unprivileged process whose hard limit is lower), fall back
# to the hard limit instead of letting `set -e` abort the job on this line.
if ! ulimit -n 65535 2>/dev/null; then
    echo "WARNING: cannot set open-file limit to 65535; using hard limit $(ulimit -Hn)" >&2
    ulimit -Sn "$(ulimit -Hn)"
fi

# Network interface detection.
#######################################
# Print the name of the first usable network interface (used for
# GLOO_SOCKET_IFNAME).
# Interfaces are examined in `ip addr show` order. An interface is usable when
# its <...> flag list contains exactly "UP" and not "NO-CARRIER", it is not
# "lo", and it carries an IPv4 address that is neither loopback (127.0.0.0/8)
# nor link-local (169.254.0.0/16) and is either RFC 1918 private
# (10/8, 172.16/12, 192.168/16) or has "scope global".
# Globals:   none
# Arguments: none
# Outputs:   the interface name on stdout, with any "@ifN" peer suffix removed
#            (e.g. "eth0@if12" -> "eth0"); an empty line when nothing
#            qualifies or `ip` fails
# Returns:   0
#######################################
find_interface() {
  local -r header_re='^[0-9]+: ([^:]+): <([^>]*)>'
  local -r inet_re='inet ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)/[0-9]+ .*scope ([^ ]+)'
  local -r private_re='^(10\.[0-9]{1,3}|172\.(1[6-9]|2[0-9]|3[0-1])|192\.168)\.[0-9]{1,3}\.[0-9]{1,3}$'
  local ip_output line iface="" flags ip_address scope
  local iface_usable=false selected_interface=""

  # Declared separately so a failing `ip` is not masked by `local`; a failed
  # listing simply yields no interface. The full listing is parsed: truncating
  # it could hide the only usable interface.
  ip_output=$(ip addr show) || ip_output=""

  # Single pass: header lines switch the "current interface", inet lines that
  # follow are attributed to it.
  while IFS= read -r line; do
    if [[ "$line" =~ ^[0-9]+: ]]; then
      iface_usable=false
      if [[ "$line" =~ $header_re ]]; then
        # veth interfaces in containers are listed as "eth0@if12"; the real
        # interface name is the part before '@'.
        iface=${BASH_REMATCH[1]%%@*}
        # Wrap in commas so "UP" does not match inside "LOWER_UP".
        flags=",${BASH_REMATCH[2]},"
        if [[ "$iface" != "lo" && "$flags" == *,UP,* && "$flags" != *,NO-CARRIER,* ]]; then
          iface_usable=true
        fi
      fi
      continue
    fi

    $iface_usable || continue
    [[ "$line" =~ $inet_re ]] || continue
    ip_address=${BASH_REMATCH[1]}
    scope=${BASH_REMATCH[2]}

    # Skip loopback and link-local addresses.
    case "$ip_address" in
      127.*|169.254.*) continue ;;
    esac

    if [[ "$ip_address" =~ $private_re || "$scope" == "global" ]]; then
      selected_interface=$iface
      break
    fi
  done <<< "$ip_output"

  echo "$selected_interface"
}

# Multi-node configuration.
# VC_WORKER_HOSTS (comma-separated host list) is only defined for multi-node jobs.
if [[ -v VC_WORKER_HOSTS ]]; then
    # Parse VC_WORKER_HOSTS. The comma IFS is scoped to this `read` only; a bare
    # `IFS=','` would leak into the rest of the script and change how every
    # unquoted expansion is word-split.
    IFS=',' read -ra myvar <<< "$VC_WORKER_HOSTS"

    echo "Number of elements in the vc worker hosts array: ${#myvar[@]}"
    WORLD_SIZE=${MA_NUM_HOSTS:-"1"}
    # Fail with a readable message instead of an "unbound variable" error from
    # the array lookup below under `set -u`.
    if (( WORLD_SIZE > ${#myvar[@]} )); then
        echo "ERROR: MA_NUM_HOSTS=${WORLD_SIZE} exceeds the ${#myvar[@]} host(s) in VC_WORKER_HOSTS" >&2
        exit 1
    fi
    # NOTE(review): this picks the host at index WORLD_SIZE-1 (the last one when
    # MA_NUM_HOSTS matches the list), while the Ray head is started on the node
    # with task index 0; confirm which host this variable is meant to name.
    export RAY_MASTER_NODE_ADDRESS=${myvar[WORLD_SIZE-1]}
    export RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-40000)
else
    RAY_MASTER_NODE_ADDRESS="0.0.0.0"
    RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-65535)
    WORLD_SIZE=1
fi

# Network environment variables.
# NCCL is pinned to ens2f5; Gloo starts from the same interface (rank 0 may
# replace it later with a detected interface).
export NCCL_SOCKET_IFNAME=ens2f5
export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME}
export NCCL_NET_PLUGIN=none
export NCCL_IB_TIMEOUT=22
export NCCL_IB_RETRY_CNT=15
export NCCL_DEBUG=INFO
# VC_WORKER_HOSTS only exists in multi-node jobs; default it to empty so
# `set -u` does not abort single-node runs here.
MASTER_HOST="${VC_WORKER_HOSTS:-}"
MASTER_ADDR="${MASTER_HOST%%,*}"   # first host of the comma-separated list
NODE_RANK="${VC_TASK_INDEX:-0}"
GPUS_PER_NODE="${MA_NUM_GPUS:-8}"

export HOST_IP=0.0.0.0
export VLLM_HOST_IP=0.0.0.0
cd /home/ma-user/work/lilong/deepscaler/new_verl/Reinforce-Ada

# Set the SwanLab key with tracing disabled so `set -x` does not print it to
# the job logs. A SWANLAB_API_KEY already present in the environment wins.
# TODO: the fallback key is committed to source; rotate it and provide the key
# via the environment or a secrets file instead.
{ xtrace_was_on=false; [[ $- == *x* ]] && xtrace_was_on=true; set +x; } 2>/dev/null
export SWANLAB_API_KEY="${SWANLAB_API_KEY:-UPRl4xYX8fWFJsWJdCmfy}"
if $xtrace_was_on; then set -x; fi

export WORKING_DIR="${PWD}"

# Experiment naming and logging settings.
use_multiround_adaptive_downsampling=0
model_name_or_path=${1:-"Qwen3-4B-Base"}   # $1: model passed to actor_rollout_ref.model.path
model_name=${2:-'model_name'}              # $2: short name embedded in the experiment name
prefix_name=test

use_dynamic_kl=True

sft_loss_coeff=0.02
max_age=8

temperature=1.0
NGPUS=8
train_prompt_bsz=512
train_prompt_mini_bsz=128
sp_size=1
tp_size=1
ppo_micro_batch_size_per_gpu=16
use_dynamic_bsz=True
offload=False

max_prompt_length=$((512 * 1))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 2))
overlong_penalty_factor=0.0
max_response_length=$((6144))

project_name=Reinforce-Ada
MODEL_PATH=${model_name_or_path}
experiment_name=${prefix_name}_${model_name}
exp_name=${experiment_name}

# Output directory for checkpoints and logs.
ckpts_dir="./outputs/${prefix_name}/${experiment_name}"
mkdir -p "${ckpts_dir}/logs"

# Training settings.
# Per-GPU token budgets: 2x (actor update) and 3x (log-prob computation) the
# maximum prompt + response length.
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))

loss_agg_mode="token-mean"

top_p=1.0
top_k=-1
val_temperature=0.6
n=8   # rollouts per prompt (actor_rollout_ref.rollout.n)
val_top_p=0.95

# Algorithm settings.
adv_estimator=grpo
kl_coef=0.0
use_kl_in_reward=False
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
dynamic_kl_batch_size=512


## Reinforce-Ada settings
if [[ "$use_multiround_adaptive_downsampling" == "1" ]]; then
    multiround_adaptive_downsampling=True
else
    multiround_adaptive_downsampling=False
fi

reinforce_ada_choice="balanced" # "positive_focused" or "balanced"
global_stat_est=True
norm_adv_by_std_in_grpo=False

# Training / validation data, passed to Hydra as single-element lists.
repo_dir=/home/ma-user/work/lilong/deepscaler/new_verl/Reinforce-Ada
train_path=${repo_dir}/data/reinforce_ada_hard_prompt/train_format.parquet
test_path=${repo_dir}/data_process/dapo/combined.parquet
train_files="['$train_path']"
test_files="['$test_path']"
# Put the repository first on PYTHONPATH (so processes launched from this
# environment import its packages) while keeping any entries already present.
# A `:-` default would skip the repo entirely whenever PYTHONPATH is set.
export PYTHONPATH="${repo_dir}${PYTHONPATH:+:${PYTHONPATH}}"
# Ray cluster bring-up.
# Rank 0 starts the Ray head node, publishes its address and Gloo interface name
# under ip_tmp/, waits, and launches training. Every other rank waits for the
# published address and joins the cluster as a Ray worker.
# NOTE(review): ip_tmp/ is relative to the working directory, so this assumes
# all nodes see the same filesystem there — confirm for the target cluster.
head_ip_file="ip_tmp/ip_${experiment_name}.txt"
gloo_file="ip_tmp/gloo_${experiment_name}.txt"

if [[ "$NODE_RANK" == "0" ]]; then
    # Remove files left by an earlier run with the same experiment name so that
    # workers cannot connect to a stale head address.
    mkdir -p ip_tmp
    rm -f -- "$head_ip_file" "$gloo_file"

    # Start the Ray head node.
    ray_output=$(ray start --head --num-gpus "${NGPUS}")

    # Extract "ip:port" from the "ray start --address='...'" hint Ray prints.
    ip_address=$(grep -oP "ray start --address='\K[^']+" <<< "$ray_output" | head -n 1 || true)
    if [[ -z "$ip_address" ]]; then
        echo "ERROR: could not find the head address in 'ray start --head' output" >&2
        exit 1
    fi

    # Choose the Gloo interface; keep the preset value if detection finds nothing.
    detected_ifname=$(find_interface)
    if [[ -n "$detected_ifname" ]]; then
        export GLOO_SOCKET_IFNAME="$detected_ifname"
    fi
    echo "$GLOO_SOCKET_IFNAME" > "$gloo_file"

    # Publish the head address last, via write-then-rename: once workers see this
    # file it is complete and the Gloo file already exists.
    echo "$ip_address" > "${head_ip_file}.tmp"
    mv -f -- "${head_ip_file}.tmp" "$head_ip_file"
    cat "$head_ip_file"

    # Give workers time to join before training starts.
    sleep 60
    ray status

    # Launch training.
    python3 -m verl.trainer.main_ppo \
        data.train_files="${train_files}" \
        data.val_files="${test_files}" \
        data.prompt_key=prompt \
        data.truncation='left' \
        data.max_prompt_length="${max_prompt_length}" \
        data.max_response_length="${max_response_length}" \
        data.train_batch_size="${train_prompt_bsz}" \
        data.dynamic_kl_batch_size="${dynamic_kl_batch_size}" \
        data.max_age="${max_age}" \
        algorithm.dynamic_kl="${use_dynamic_kl}" \
        actor_rollout_ref.rollout.n="${n}" \
        algorithm.multiround_adaptive_downsampling="${multiround_adaptive_downsampling}" \
        algorithm.reinforce_ada_choice="${reinforce_ada_choice}" \
        algorithm.global_stat_est="${global_stat_est}" \
        algorithm.norm_adv_by_std_in_grpo="${norm_adv_by_std_in_grpo}" \
        algorithm.adv_estimator="${adv_estimator}" \
        algorithm.use_kl_in_reward="${use_kl_in_reward}" \
        algorithm.kl_ctrl.kl_coef="${kl_coef}" \
        algorithm.filter_groups.enable=False \
        algorithm.filter_groups.metric='seq_final_reward' \
        algorithm.filter_groups.max_num_gen_batches=10 \
        actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}" \
        actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}" \
        actor_rollout_ref.actor.sft_loss_coeff="${sft_loss_coeff}" \
        actor_rollout_ref.actor.clip_ratio_low="${clip_ratio_low}" \
        actor_rollout_ref.actor.clip_ratio_high="${clip_ratio_high}" \
        actor_rollout_ref.actor.clip_ratio_c=10.0 \
        actor_rollout_ref.model.use_remove_padding=True \
        +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
        actor_rollout_ref.actor.use_dynamic_bsz="${use_dynamic_bsz}" \
        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="${ppo_micro_batch_size_per_gpu}" \
        actor_rollout_ref.ref.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
        actor_rollout_ref.rollout.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu="${actor_ppo_max_token_len}" \
        actor_rollout_ref.ref.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
        actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
        actor_rollout_ref.rollout.name=vllm \
        actor_rollout_ref.model.path="${MODEL_PATH}" \
        actor_rollout_ref.model.enable_gradient_checkpointing=True \
        actor_rollout_ref.actor.optim.lr=1e-6 \
        actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
        actor_rollout_ref.actor.optim.weight_decay=0. \
        actor_rollout_ref.actor.ppo_mini_batch_size="${train_prompt_mini_bsz}" \
        actor_rollout_ref.actor.fsdp_config.param_offload="${offload}" \
        actor_rollout_ref.actor.fsdp_config.optimizer_offload="${offload}" \
        actor_rollout_ref.actor.entropy_coeff=0 \
        actor_rollout_ref.actor.grad_clip=1.0 \
        actor_rollout_ref.actor.use_torch_compile=False \
        actor_rollout_ref.actor.loss_agg_mode="${loss_agg_mode}" \
        actor_rollout_ref.actor.ulysses_sequence_parallel_size="${sp_size}" \
        actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
        actor_rollout_ref.rollout.tensor_model_parallel_size="${tp_size}" \
        actor_rollout_ref.rollout.enable_chunked_prefill=True \
        actor_rollout_ref.rollout.max_num_batched_tokens="$((max_prompt_length + max_response_length))" \
        actor_rollout_ref.rollout.temperature="${temperature}" \
        actor_rollout_ref.rollout.top_p="${top_p}" \
        actor_rollout_ref.rollout.top_k="${top_k}" \
        actor_rollout_ref.rollout.val_kwargs.temperature="${val_temperature}" \
        actor_rollout_ref.rollout.val_kwargs.top_p="${val_top_p}" \
        actor_rollout_ref.rollout.val_kwargs.top_k="${top_k}" \
        actor_rollout_ref.rollout.val_kwargs.do_sample=True \
        actor_rollout_ref.rollout.val_kwargs.n=32 \
        actor_rollout_ref.ref.fsdp_config.param_offload=True \
        actor_rollout_ref.ref.ulysses_sequence_parallel_size="${sp_size}" \
        actor_rollout_ref.actor.fsdp_config.fsdp_size="${NGPUS}" \
        reward_model.reward_manager=naive \
        +reward_model.reward_kwargs.overlong_buffer_cfg.enable="${enable_overlong_buffer}" \
        +reward_model.reward_kwargs.overlong_buffer_cfg.len="${overlong_buffer_len}" \
        +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor="${overlong_penalty_factor}" \
        +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
        +reward_model.reward_kwargs.max_resp_len="${max_response_length}" \
        trainer.validation_data_dir="/home/ma-user/work/lilong/deepscaler/new_verl/Reinforce-Ada/val_results/${exp_name}" \
        trainer.rollout_data_dir="/home/ma-user/work/lilong/deepscaler/new_verl/Reinforce-Ada/rollout_results/${exp_name}" \
        trainer.resume_mode="auto" \
        trainer.logger='["console","swanlab"]' \
        trainer.project_name="${project_name}" \
        trainer.experiment_name="${exp_name}" \
        trainer.n_gpus_per_node="${NGPUS}" \
        trainer.nnodes="${WORLD_SIZE}" \
        trainer.val_before_train=True \
        trainer.val_only=True \
        trainer.test_freq=40 \
        trainer.save_freq=40 \
        trainer.total_epochs=1 \
        trainer.total_training_steps=5 \
        trainer.default_local_dir="${ckpts_dir}" \
        trainer.log_val_generations=10

else
    # Worker node.
    # Give rank 0 time to clear stale files, then wait for the head address.
    sleep 15
    waited=0
    until [[ -s "$head_ip_file" ]]; do
        if (( waited >= 600 )); then
            echo "ERROR: timed out after ${waited}s waiting for ${head_ip_file}" >&2
            exit 1
        fi
        sleep 5
        waited=$((waited + 5))
    done

    # Read the head address and the Gloo interface chosen by rank 0.
    head_ip=$(< "$head_ip_file")
    gloo=$(< "$gloo_file")
    if [[ -n "$gloo" ]]; then
        export GLOO_SOCKET_IFNAME="$gloo"
    fi
    echo "gloo: $GLOO_SOCKET_IFNAME"
    echo "Head IP Address: $head_ip"

    # Join the Ray cluster.
    # NOTE(review): without --block, `ray start` returns right away and this
    # script exits; confirm the platform keeps the node alive afterwards.
    ray start --address "${head_ip}"
    echo "$HOST_IP"
fi