#!/usr/bin/env bash
set -euo pipefail

# Launcher for weighted SFT (src/weight_sft.py) under DeepSpeed.
# Edit the values below as needed. Relative paths (training data, DeepSpeed
# config, the training script, output directory) resolve against the current
# working directory, so run this from the directory that contains src/.

MODEL_NAME_OR_PATH="Qwen/Qwen2.5-3B-Instruct"  # Replace with your model path
TRAIN_JSONL="data/train.jsonl"  # Replace with your training data path
EVAL_SPLIT_RATIO=0.05
EVAL_SPLIT_SEED=42
NUM_GPUS=4
# NOTE(review): keep ZERO_STAGE consistent with the stage configured in
# DEEPSPEED_CONFIG_PATH; both are passed to the training script.
ZERO_STAGE=3
DEEPSPEED_CONFIG_PATH="src/deepspeed_zero3.json"
WANDB_PROJECT="drift"
RUN_NAME="scaling"
REPORT_TO="none"  # "wandb" additionally exports WANDB_PROJECT / WANDB_NAME

# Name the checkpoint directory after the model (last path component of
# MODEL_NAME_OR_PATH, ignoring one trailing slash) so that changing the model
# does not write into another model's checkpoint folder. For the default
# model this yields "Qwen2.5-3B-Instruct".
MODEL_BASENAME="${MODEL_NAME_OR_PATH%/}"
MODEL_BASENAME="${MODEL_BASENAME##*/}"
OUTPUT_DIR="./checkpoint/weighted-sft/${MODEL_BASENAME}/${RUN_NAME}"

# Sequence-length limits (presumably in tokens -- confirm in src/weight_sft.py).
MAX_COMPLETION_LENGTH=512
MAX_PROMPT_LENGTH=8192
# Samples per optimizer step is normally
# PER_DEVICE_TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_GPUS.
PER_DEVICE_TRAIN_BATCH_SIZE=2
PER_DEVICE_EVAL_BATCH_SIZE=4
GRAD_ACCUM_STEPS=16
LEARNING_RATE=5e-6
NUM_TRAIN_EPOCHS=1
WARMUP_RATIO=0.1
# Method-specific hyperparameters, forwarded as --gamma / --beta.
GAMMA=0.9
BETA=0.1
LOGGING_STEPS=1
EVAL_STEPS=20
SAVE_STEPS=20
SAVE_TOTAL_LIMIT=100
WEIGHT_DECAY=0.1
MAX_GRAD_NORM=1.0
# BF16, FP16, TRUST_REMOTE_CODE and USE_LORA are switched on only by the exact
# value "1"; anything else omits the flag. Enable at most one of BF16 / FP16.
BF16=1
FP16=0
TRUNCATION_SIDE="left"
SEED=42
TRUST_REMOTE_CODE=0
RESUME_FROM_CHECKPOINT=""  # empty: --resume_from_checkpoint is not passed

# The LoRA settings below are forwarded only when USE_LORA=1.
USE_LORA=0
LORA_R=8
LORA_ALPHA=32
LORA_DROPOUT=0.05
LORA_TARGET_MODULES="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj"
LORA_BIAS="none"
LORA_TASK_TYPE="CAUSAL_LM"

# Only when logging to Weights & Biases: publish the project, and the run name
# if one is configured, through the WANDB_PROJECT / WANDB_NAME environment
# variables that the wandb client reads.
case "$REPORT_TO" in
  wandb)
    export WANDB_PROJECT="$WANDB_PROJECT"
    [[ -z "$RUN_NAME" ]] || export WANDB_NAME="$RUN_NAME"
    ;;
esac

# Presumably read by src/weight_sft.py to control batch-level normalization of
# the sample weights. NOTE(review): confirm the accepted values in that script.
export BATCH_WEIGHT_NORMAL="False"

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Mixed-precision modes are mutually exclusive (transformers' TrainingArguments,
# presumably used downstream, rejects fp16 together with bf16). Fail here
# instead of after DeepSpeed has started one process per GPU.
if [[ "$BF16" == "1" && "$FP16" == "1" ]]; then
  die "BF16 and FP16 are both set to 1; enable at most one"
fi

# Assemble the command line in an array so every value is passed as exactly
# one argument, even when it contains spaces or glob characters. Optional
# flags are appended in the same order the script has always used.
train_args=(
  --model_name_or_path "$MODEL_NAME_OR_PATH"
  --train_jsonl "$TRAIN_JSONL"
  --output_dir "$OUTPUT_DIR"
  --max_prompt_length "$MAX_PROMPT_LENGTH"
  --max_completion_length "$MAX_COMPLETION_LENGTH"
  --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE"
  --gradient_accumulation_steps "$GRAD_ACCUM_STEPS"
  --learning_rate "$LEARNING_RATE"
  --num_train_epochs "$NUM_TRAIN_EPOCHS"
  --warmup_ratio "$WARMUP_RATIO"
  --gamma "$GAMMA"
  --beta "$BETA"
  --logging_steps "$LOGGING_STEPS"
  --save_steps "$SAVE_STEPS"
  --save_total_limit "$SAVE_TOTAL_LIMIT"
  --weight_decay "$WEIGHT_DECAY"
  --max_grad_norm "$MAX_GRAD_NORM"
  --truncation_side "$TRUNCATION_SIDE"
  --seed "$SEED"
  --zero_stage "$ZERO_STAGE"
  --report_to "$REPORT_TO"
)
if [[ -n "$RUN_NAME" ]]; then
  train_args+=(--run_name "$RUN_NAME")
fi
train_args+=(
  --eval_split_ratio "$EVAL_SPLIT_RATIO"
  --eval_split_seed "$EVAL_SPLIT_SEED"
  --eval_steps "$EVAL_STEPS"
  --disable_shuffle
)
if [[ -n "$PER_DEVICE_EVAL_BATCH_SIZE" ]]; then
  train_args+=(--per_device_eval_batch_size "$PER_DEVICE_EVAL_BATCH_SIZE")
fi
if [[ -n "$DEEPSPEED_CONFIG_PATH" ]]; then
  train_args+=(--deepspeed_config_path "$DEEPSPEED_CONFIG_PATH")
fi
# Boolean switches are enabled only by the exact value "1".
if [[ "$BF16" == "1" ]]; then
  train_args+=(--bf16)
fi
if [[ "$FP16" == "1" ]]; then
  train_args+=(--fp16)
fi
if [[ "$TRUST_REMOTE_CODE" == "1" ]]; then
  train_args+=(--trust_remote_code)
fi
if [[ -n "$RESUME_FROM_CHECKPOINT" ]]; then
  train_args+=(--resume_from_checkpoint "$RESUME_FROM_CHECKPOINT")
fi
# LoRA options are forwarded only when LoRA is enabled.
if [[ "$USE_LORA" == "1" ]]; then
  train_args+=(
    --use_lora
    --lora_r "$LORA_R"
    --lora_alpha "$LORA_ALPHA"
    --lora_dropout "$LORA_DROPOUT"
    --lora_target_modules "$LORA_TARGET_MODULES"
    --lora_bias "$LORA_BIAS"
    --lora_task_type "$LORA_TASK_TYPE"
  )
fi

deepspeed --num_gpus "$NUM_GPUS" src/weight_sft.py "${train_args[@]}"
