#!/usr/bin/env bash
# RACO hyper-parameter sweep: for every (c, lr, clip_lambda, weight spec)
# combination, train a model with TRL's scripts/train_raco.py and then score
# it with the eval pipeline selected by EVAL_MODE.
#
# Usage: bash <this script> [WEIGHT_SPEC ...]
#   WEIGHT_SPEC is "wq,wv" or "wq" (then wv = 1 - wq); with no arguments the
#   DEFAULT_WEIGHTS below are used.
# Optional env: EVAL_MODE, LRS_OVERRIDE, CS_OVERRIDE, CLIP_LAMBDAS_OVERRIDE,
#   NOHUP_TRAIN=1 (training ignores SIGHUP), SKIP_EVAL=1 (train only).
set -euo pipefail

# Local TRL checkout, plus the virtualenvs used for training and evaluation.
# NOTE: training sources TRL_VENV *after* `cd "$TRL_DIR"`, so a relative
# TRL_VENV is resolved against TRL_DIR, not the directory you launch from.
TRL_DIR="../trl"
TRL_VENV="../trl/env/bin/activate"
EVAL_VENV="../eval/env/bin/activate"

# Local model path and training / validation dataset paths.
# Training runs from TRL_DIR, so relative paths here are resolved against it.
BASE_MODEL="" # your local model path
TRAIN_DATASET="" # your local training dataset path
VAL_DATASET="" # your local validation dataset path

# If you are working on the safety-alignment task (EVAL_MODE=beaver), fill in these paths.
BEAVER_REWARD_MODEL_DIR="" # your local reward model path
BEAVER_COST_MODEL_DIR="" # your local cost model path
BEAVER_EVAL_PROMPTS="" # your local evaluation prompts path (parquet with a `prompt` column)

# If you are working on the summarization task (EVAL_MODE=summary), fill in these paths.
SUMMARY_QUALITY_MODEL_DIR="" # your local quality model path
SUMMARY_FAITHFUL_MODEL_DIR="" # your local faithfulness model path
EVAL_PROMPTS="" # your local evaluation prompts path (parquet with a `prompt` column)

# Eval mode:
# - summary: uses eval/new_score.py (quality + faithful judges on RedditSummary prompts)
# - beaver:  uses eval/new_score_beaver.py (beaver reward + cost judges on BeaverTails prompts)
EVAL_MODE="${EVAL_MODE:-beaver}"

OUTPUT_ROOT="" # your local output path (one checkpoint dir per run)
LOG_DIR="" # your local log path
EVAL_OUTPUT_DIR="" # your local eval output path (generations + scores)

# Fail fast with a readable message if a required path was left empty;
# otherwise this surfaces as an opaque `mkdir` error or a failed training run.
: "${BASE_MODEL:?set BASE_MODEL near the top of this script}"
: "${TRAIN_DATASET:?set TRAIN_DATASET near the top of this script}"
: "${OUTPUT_ROOT:?set OUTPUT_ROOT near the top of this script}"
: "${LOG_DIR:?set LOG_DIR near the top of this script}"
: "${EVAL_OUTPUT_DIR:?set EVAL_OUTPUT_DIR near the top of this script}"

mkdir -p "$LOG_DIR" "$EVAL_OUTPUT_DIR" "$OUTPUT_ROOT"
# Training launches 8 processes and eval uses --tp 8; keep this list in sync.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # set your local GPU devices

# Training defaults (edit if needed)
MAX_LENGTH=2048
TRAIN_BS=2 # per-device train batch size
GRAD_ACCUM=4
WARMUP_RATIO="0.1"
SCHEDULER="cosine"

# RACO defaults (edit if needed)
# NOTE: c is now a sweep variable (see DEFAULT_CS / CS_OVERRIDE below); this
# value is kept for reference only and is not passed to training.
RACO_C="0.4"
RACO_USE_CAGRAD="True" # Enable CAGrad; setting it to "False" gives DPO LW
# NOTE: clip lambda is also a sweep variable (see DEFAULT_CLIP_LAMBDAS /
# CLIP_LAMBDAS_OVERRIDE below); this value is not passed to training.
RACO_CLIP_LAMBDA="True" # Enable CAGrad-Clip
LENGTH_NORMALIZED="False"

# Default sweep weights + LRs + Cs + clip lambdas
# Default weight specs (override by passing specs as CLI arguments)
DEFAULT_WEIGHTS=("0.8,0.2" "0.5,0.5" "0.2,0.8")
#DEFAULT_WEIGHTS=("0.35,0.65" "0.65,0.35")
# Default LRs (override via env: LRS_OVERRIDE="1e-5 5e-6")
DEFAULT_LRS=("1e-5")
# Default Cs (override via env: CS_OVERRIDE="0.2 0.4 0.6")
DEFAULT_CS=("0.5" "0.25")
# Default clip lambdas (override via env: CLIP_LAMBDAS_OVERRIDE="True False")
DEFAULT_CLIP_LAMBDAS=("True")

#######################################
# Make a string safe to embed in file names: each '/', ' ', ':' and ','
# becomes '_'; every other character is kept so the tag stays readable.
# Arguments: $1 - raw string
# Outputs:   sanitized string on stdout (newline-terminated)
#######################################
to_tag() {
  # printf instead of echo: bash's echo would swallow an argument such as
  # "-n" or "-e". Parameter expansion also avoids forking `tr` per call.
  local unsafe='[/ :,]'
  printf '%s\n' "${1//$unsafe/_}"
}

#######################################
# Split a weight spec into the two RACO objective weights.
# Arguments:
#   $1 - "wq,wv" (both weights given) or "wq" (wv is computed as 1 - wq)
# Outputs:
#   "wq wv" (space-separated) on stdout
# Returns:
#   1, with a message on stderr, if a component is empty or, in the
#   single-weight form, wq is not a number python3's float() accepts.
#######################################
parse_weights() {
  local spec="$1"
  local wq wv
  if [[ "$spec" == *","* ]]; then
    wq="${spec%%,*}"
    wv="${spec#*,}"
  else
    wq="$spec"
    # wv := 1 - wq. Bash arithmetic is integer-only, so use python3. wq is
    # passed via argv instead of being spliced into the Python source, so an
    # unusual spec cannot break (or inject into) the program text.
    if ! wv="$(python3 -c 'import sys; print(f"{1.0 - float(sys.argv[1]):.10g}")' "$wq")"; then
      printf 'parse_weights: invalid weight %s\n' "'${wq}'" >&2
      return 1
    fi
  fi
  if [[ -z "$wq" || -z "$wv" ]]; then
    printf 'parse_weights: empty weight in spec %s\n' "'${spec}'" >&2
    return 1
  fi
  printf '%s %s\n' "$wq" "$wv"
}

#######################################
# Train one RACO sweep point, then (unless SKIP_EVAL=1) generate and score it.
# Globals (read):
#   TRL_DIR, TRL_VENV, EVAL_VENV, BASE_MODEL, TRAIN_DATASET, VAL_DATASET,
#   OUTPUT_ROOT, LOG_DIR, EVAL_OUTPUT_DIR, MAX_LENGTH, TRAIN_BS, GRAD_ACCUM,
#   WARMUP_RATIO, SCHEDULER, RACO_USE_CAGRAD, LENGTH_NORMALIZED, EVAL_MODE,
#   BEAVER_REWARD_MODEL_DIR, BEAVER_COST_MODEL_DIR, BEAVER_EVAL_PROMPTS,
#   SUMMARY_QUALITY_MODEL_DIR, SUMMARY_FAITHFUL_MODEL_DIR, EVAL_PROMPTS
# Optional env:
#   NOHUP_TRAIN=1  run the training command under nohup (ignores SIGHUP)
#   SKIP_EVAL=1    train only
#   STOP_TOKEN     generation stop string; default "<|im_end|>" (Qwen3).
#                  Use "<end_of_turn>" for Gemma3 or "<|eot_id|>" for Llama3.
# Arguments:
#   $1 - weight spec, "wq,wv" or "wq" (see parse_weights)
#   $2 - learning rate
#   $3 - RACO c value (--raco_c)
#   $4 - clip-lambda flag (--raco_clip_lambda)
#   $5 - port for the eval-time generation server
# Outputs:
#   Training and eval stdout/stderr are written to files under LOG_DIR.
# Returns:
#   Non-zero if EVAL_MODE is unknown, the weight spec cannot be parsed, or a
#   training/eval step fails (which aborts the sweep under `set -e`).
#######################################
run_one() {
  local spec="$1"
  local lr="$2"
  local raco_c="$3"
  local clip_lambda="$4"
  local port="$5"
  local skip_eval="${SKIP_EVAL:-0}"

  # Reject an unknown eval mode *before* spending hours on training; any
  # value other than "beaver" used to fall through to the summary eval.
  if [[ "$skip_eval" != "1" && "$EVAL_MODE" != "beaver" && "$EVAL_MODE" != "summary" ]]; then
    printf 'run_one: unknown EVAL_MODE %s (expected beaver or summary)\n' "'${EVAL_MODE}'" >&2
    return 1
  fi

  # Plain command substitution (not `read < <(...)`) so that a parse failure
  # is actually seen instead of being lost inside a process substitution.
  local parsed wq wv
  parsed="$(parse_weights "$spec")" || return 1
  read -r wq wv <<<"$parsed"

  local tag
  tag="$(to_tag "wq${wq}_wv${wv}_lr${lr}_c${raco_c}_clip${clip_lambda}_ln${LENGTH_NORMALIZED}")"

  # Use absolute paths: training runs after `cd "$TRL_DIR"` while eval runs
  # from the launch directory, so a relative OUTPUT_ROOT/LOG_DIR would make
  # training write somewhere eval never looks (or to a missing log dir).
  local script_dir output_root log_dir
  script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || return 1
  output_root="$(cd "$OUTPUT_ROOT" && pwd)" || return 1
  log_dir="$(cd "$LOG_DIR" && pwd)" || return 1

  local output_dir="${output_root}/raco-${tag}"
  local train_out="${log_dir}/unclip-${tag}-train.out"
  local train_err="${log_dir}/unclip-${tag}-train.err"
  local eval_out="${log_dir}/unclip-${tag}-eval.out"
  local eval_err="${log_dir}/unclip-${tag}-eval.err"

  # Train (blocking). Run in a subshell so the cd and venv activation can't
  # leak into later runs.
  (
    cd "$TRL_DIR"
    source "$TRL_VENV"

    # NOTE: NOHUP_TRAIN=1 helps runs survive an SSH disconnect by making the *training*
    # process ignore SIGHUP. Still recommended: run the whole sweep under tmux/screen.
    train_cmd=(env PYTHONPATH=. accelerate launch
      --main_process_port 0
      --config_file ../script/multi_gpu.yaml
      --num_processes 8
      scripts/train_raco.py
      --mode raco
      --model_name_or_path "$BASE_MODEL"
      --dataset_path "$TRAIN_DATASET"
      --val_dataset_path "$VAL_DATASET"
      --output_dir "$output_dir"
      --max_length "$MAX_LENGTH"
      --per_device_train_batch_size "$TRAIN_BS"
      --gradient_accumulation_steps "$GRAD_ACCUM"
      --learning_rate "$lr"
      --logging_steps 1
      --bf16 True
      --raco True
      --raco_weights "${wq},${wv}"
      --raco_c "$raco_c"
      --raco_use_cagrad "$RACO_USE_CAGRAD"
      --per_device_eval_batch_size 8
      --eval_strategy steps
      --eval_steps 100
      --report_to none
      --run_name "raco/wq=${wq},wv=${wv},lr=${lr},ln=${LENGTH_NORMALIZED}"
      --warmup_ratio "$WARMUP_RATIO"
      --length_normalized "$LENGTH_NORMALIZED"
      --raco_clip_lambda "$clip_lambda"
      --lr_scheduler_type "$SCHEDULER"
    )

    if [[ "${NOHUP_TRAIN:-0}" == "1" ]]; then
      nohup "${train_cmd[@]}" >"$train_out" 2>"$train_err" < /dev/null
    else
      "${train_cmd[@]}" >"$train_out" 2>"$train_err"
    fi
  )

  if [[ "$skip_eval" == "1" ]]; then
    echo "SKIP_EVAL=1 set; skipping eval for ${output_dir}"
    return
  fi

  # Mode-specific pieces of the eval command. Everything before `--` goes to
  # the scoring (judge) step; everything after it goes to generation.
  local stop_token="${STOP_TOKEN:-<|im_end|>}"
  local eval_script prompts server_max_len max_tokens
  local -a judge_args
  if [[ "$EVAL_MODE" == "beaver" ]]; then
    # Beaver workflow: generate on BeaverTails prompts, then score reward+cost.
    eval_script="${script_dir}/../eval/new_score_beaver.py"
    judge_args=(
      --reward_model_dir "$BEAVER_REWARD_MODEL_DIR"
      --cost_model_dir "$BEAVER_COST_MODEL_DIR"
      --score_device cuda
      --score_fp16
      --score_batch_size 8
      --score_max_length 4096
    )
    server_max_len=4096
    prompts="$BEAVER_EVAL_PROMPTS"
    max_tokens=2048
  else
    # Summarization workflow: generate on RedditSummary prompts, then score
    # quality+faithful.
    eval_script="${script_dir}/../eval/new_score.py"
    judge_args=(
      --quality_model_dir "$SUMMARY_QUALITY_MODEL_DIR"
      --faithful_model_dir "$SUMMARY_FAITHFUL_MODEL_DIR"
      --quality_score_mode input_output
      --score_device cuda
      --score_fp16
      --score_batch_size 128
      --score_max_length 1024
    )
    server_max_len=3072
    prompts="$EVAL_PROMPTS"
    max_tokens=512
  fi

  # NOTE: vllm_generate.py expects `--model` to be the *served model name*;
  # the checkpoint path doubles as that name for simplicity.
  local -a gen_args=(
    --output_dir "$EVAL_OUTPUT_DIR"
    --start_server
    --server_model_path "$output_dir"
    --served_model_name "$output_dir"
    --server_bind_host 0.0.0.0
    --server_ready_host 127.0.0.1
    --host 127.0.0.1
    --port "$port"
    --tp 8
    --server_max_model_len "$server_max_len"
    --model "$output_dir"
    --input_parquet "$prompts"
    --prompt_column prompt
    --concurrency 64
    --max_tokens "$max_tokens"
    --temperature 0.6
    --stop "$stop_token"
  )

  # Eval (blocking), in a subshell so the eval venv doesn't leak.
  (
    source "$EVAL_VENV"
    python "$eval_script" "${judge_args[@]}" -- "${gen_args[@]}" \
      >"$eval_out" 2>"$eval_err"
  )
}

# ---- Sweep axes ---------------------------------------------------------
# Weight specs come from the command line when any are given, otherwise
# from DEFAULT_WEIGHTS. LRs, Cs and clip lambdas each accept an optional
# whitespace-separated env override (LRS_OVERRIDE, CS_OVERRIDE,
# CLIP_LAMBDAS_OVERRIDE); an unset or empty override keeps the defaults.
if (( $# > 0 )); then
  WEIGHTS=("$@")
else
  WEIGHTS=("${DEFAULT_WEIGHTS[@]}")
fi

if [[ -n "${LRS_OVERRIDE:-}" ]]; then
  read -r -a LRS <<<"${LRS_OVERRIDE}"
else
  LRS=("${DEFAULT_LRS[@]}")
fi

if [[ -n "${CS_OVERRIDE:-}" ]]; then
  read -r -a CS <<<"${CS_OVERRIDE}"
else
  CS=("${DEFAULT_CS[@]}")
fi

if [[ -n "${CLIP_LAMBDAS_OVERRIDE:-}" ]]; then
  read -r -a CLIP_LAMBDAS <<<"${CLIP_LAMBDAS_OVERRIDE}"
else
  CLIP_LAMBDAS=("${DEFAULT_CLIP_LAMBDAS[@]}")
fi

# ---- Sweep --------------------------------------------------------------
# Runs execute one after another (c outermost, weight spec innermost). Each
# run gets the next port for its eval server, starting at 8200 to avoid
# collisions with other scripts that use 8000+.
port=8200
for raco_c in "${CS[@]}"; do
  for lr in "${LRS[@]}"; do
    for clip_lambda in "${CLIP_LAMBDAS[@]}"; do
      for spec in "${WEIGHTS[@]}"; do
        run_one "$spec" "$lr" "$raco_c" "$clip_lambda" "$port"
        (( port += 1 ))
      done
    done
  done
done
