# NCCL's documented NCCL_DEBUG values are VERSION, WARN, INFO and TRACE.
# The previous value "WARNING" is not among them, so warning-level NCCL
# diagnostics were presumably never enabled; "WARN" is the documented spelling.
export NCCL_DEBUG=WARN

# Month/day stamp (MMDD) of this launch, echoed so it shows up in the log.
date=$(date +"%m%d")

echo "date: ${date}"

# Positional arguments (all required):
#   $1 model      - model key used to pick the base checkpoint path
#   $2 dataset    - dataset name; "${dataset}_train" is passed as --dataset
#   $3 task_order - index of this task; values > 1 load the adapter written
#                   for task_order - 1
#   $4 run_name   - run identifier used in the checkpoint directory path
#
# Fail fast on missing/empty arguments: otherwise the script would go on to
# launch training with an empty model path and a malformed output directory.
if (( $# < 4 )) || [[ -z "$1" || -z "$2" || -z "$3" || -z "$4" ]]; then
  printf 'Usage: %s <model> <dataset> <task_order> <run_name>\n' "${0##*/}" >&2
  exit 2
fi

model="$1"
dataset="$2"
task_order="$3"
run_name="$4"

# Fixed training hyperparameters for this launcher.
epoch=5
max_len=4096
lora_rank=8


echo "———————————————————— Train ————————————————————"


# Map the model key to its local base checkpoint directory.
# An unrecognized key is a hard error: previously model_name_or_path was left
# empty and the failure only surfaced later inside the training process.
case "${model}" in
  llama3)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Meta-Llama-3-8B"
    ;;
  llama3.1)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Llama-3.1-8B"
    ;;
  qwen)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/Qwen/Qwen3-8B-Base"
    ;;
  *)
    printf 'error: unknown model "%s" (expected one of: llama3, llama3.1, qwen)\n' \
      "${model}" >&2
    exit 1
    ;;
esac

# Optional --adapter_name_or_path argument, kept as an array so that it expands
# to zero words when no adapter is used.
adapter_arg=()

# Validate task_order before it reaches arithmetic evaluation: an empty value
# is a syntax error inside (( )), a non-numeric value is silently treated as 0,
# and a leading zero (e.g. "08") is parsed as octal. Leading zeros are rejected
# rather than normalized because the raw value is also used in directory names.
if [[ ! "${task_order}" =~ ^(0|[1-9][0-9]*)$ ]]; then
  printf 'error: task_order must be a non-negative integer without leading zeros, got "%s"\n' \
    "${task_order}" >&2
  exit 1
fi

# For every task after the first, start from the adapter stored in the
# previous task_order's checkpoint directory of the same model/run.
if (( task_order > 1 )); then
  prev=$(( task_order - 1 ))
  adapter_path="ckpt/cl/sft/${model}/${run_name}/${prev}"
  adapter_arg+=( --adapter_name_or_path "$adapter_path" )
fi


# Choose the value passed as --template: gsm8k gets "vanilla_qa",
# every other dataset gets "vanilla".
case "${dataset}" in
  gsm8k) template="vanilla_qa" ;;
  *)     template="vanilla" ;;
esac

# Log the run configuration before launching.
# The last line previously expanded ${adapter_name_or_path}, a variable this
# script never assigns, so it always printed an empty value. The adapter path
# lives in ${adapter_path}, which is only set when task_order > 1 and is
# therefore empty when no adapter is loaded.
# printf '%s\n' is used instead of echo so values are printed verbatim.
printf '%s\n' \
  "model: ${model}" \
  "epoch: ${epoch}" \
  "max_len: ${max_len}" \
  "dataset: ${dataset}" \
  "task_order: ${task_order}" \
  "run_name: ${run_name}" \
  "lora_rank: ${lora_rank}" \
  "adapter_name_or_path: ${adapter_path}"


#######################################
# Print a currently unused TCP port number.
# Binds an IPv4 stream socket to port 0 so the OS assigns a free port, prints
# that port, then closes the socket. The port is not reserved afterwards.
# Arguments: none
# Outputs:   the port number on stdout
# Returns:   the exit status of python3
#######################################
find_free_port() {
  local probe_script
  probe_script='import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.bind(("", 0))
    print(sock.getsockname()[1])
finally:
    sock.close()
'
  python3 -c "${probe_script}"
}

# Pick a free port for --main_process_port. Abort if the lookup produced
# nothing usable (e.g. python3 is unavailable); previously an empty port was
# passed straight to accelerate.
main_process_port=$(find_free_port)
if [[ ! "${main_process_port}" =~ ^[0-9]+$ ]]; then
  printf 'error: could not determine a free port for accelerate launch (got "%s")\n' \
    "${main_process_port}" >&2
  exit 1
fi

# Launch single-process LoRA SFT through accelerate using the
# src/configs/accelerate_zero2.yaml config. "${adapter_arg[@]}" contributes
# --adapter_name_or_path only when a previous-task adapter was selected.
# Checkpoints are written to ckpt/cl/sft/<model>/<run_name>/<task_order>.
accelerate launch \
    --num_processes 1 \
    --main_process_port "${main_process_port}" \
    --config_file src/configs/accelerate_zero2.yaml \
    src/train_sft.py \
        --stage sft \
        --do_train true \
        --per_device_eval_batch_size 1 \
        --predict_with_generate false \
        --num_train_epochs "${epoch}" \
        --cutoff_len "${max_len}" \
        --model_name_or_path "${model_name_or_path}" \
        "${adapter_arg[@]}" \
        --dataset "${dataset}_train" \
        --dataset_dir "data/cl" \
        --template "${template}" \
        --mask_history true \
        --finetuning_type "lora" \
        --lora_target "q_proj,k_proj,v_proj,up_proj,down_proj" \
        --lora_rank "${lora_rank}" \
        --lora_dropout 0.05 \
        --output_dir "ckpt/cl/sft/${model}/${run_name}/${task_order}" \
        --overwrite_cache true \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --lr_scheduler_type "linear" \
        --warmup_steps 100 \
        --logging_strategy "steps" \
        --logging_steps 10 \
        --save_strategy "epoch" \
        --save_total_limit 1 \
        --learning_rate 3e-4 \
        --fp16 true \
        --overwrite_output_dir true \
        --report_to none