#!/bin/bash

# Print each command to stderr as it runs, so the launch is easy to follow.
set -x

# Base model handed to --model_name_or_path. Taken from the environment when
# BASE_MODEL_PATH is set and non-empty; otherwise the placeholder is used.
BASE_MODEL_PATH="${BASE_MODEL_PATH:-/path/to/base/model}"

# Dataset name handed to --eval_dataset.
EVAL_DATASETS='gsm8k_test'

# Common prefix of the per-run output directories; each run appends its own
# rank / learning-rate suffix to this value.
BASE_OUTPUT_DIR='saves/llama3-8b/meta_math/sft_mosa_param_equivalent'

#######################################
# Print the newest checkpoint directory inside a run directory.
# "Newest" means the last entry after a version sort, so checkpoint-1000
# ranks after checkpoint-500.
# Arguments: $1 - run output directory
# Outputs:   path of the newest checkpoint-* directory (no trailing slash)
#            on stdout, or nothing when the run has no checkpoint
# Returns:   0 when no checkpoint exists; otherwise the pipeline's status
#######################################
find_latest_checkpoint() {
  local run_dir=$1
  local candidate
  local -a candidates=()

  # Glob instead of parsing `ls`; an unmatched pattern stays literal and is
  # rejected by the -d test.
  for candidate in "${run_dir}"/checkpoint-*/; do
    if [[ -d "${candidate}" ]]; then
      candidates+=("${candidate%/}")
    fi
  done

  if (( ${#candidates[@]} == 0 )); then
    return 0
  fi
  printf '%s\n' "${candidates[@]}" | sort -V | tail -n 1
}

#######################################
# Launch one MoSA SFT run for a given equivalent rank, resuming from the
# newest checkpoint in the run's output directory when one exists.
# Globals:   BASE_OUTPUT_DIR, BASE_MODEL_PATH, EVAL_DATASETS (read)
# Arguments: $1 - value passed to --mosa_equivalent_rank
# Returns:   exit status of llamafactory-cli
#######################################
train_mosa_rank() {
  local r=$1
  local output_dir="${BASE_OUTPUT_DIR}_rank_${r}_lr_1e-5"
  local latest_checkpoint
  local -a resume_option=()

  latest_checkpoint=$(find_latest_checkpoint "${output_dir}")
  if [[ -d "${latest_checkpoint}" ]]; then
    echo "--- Found latest checkpoint: ${latest_checkpoint}. Setting it for resume. ---"
    resume_option=(--resume_from_checkpoint "${latest_checkpoint}")
  else
    echo "--- No checkpoint found in ${output_dir}. Starting training from scratch. ---"
  fi

  # The arguments live in an array so the resume options are part of the same
  # command. Previously a whitespace-only line after the last backslash ended
  # the command early, and the resume options ran as a separate command.
  local -a train_args=(
    --model_name_or_path "${BASE_MODEL_PATH}"
    --trust_remote_code
    --stage sft
    --do_train
    --finetuning_type mosa
    --mosa_equivalent_rank "${r}"
    --mosa_target_modules "all"
    --mosa_grouping_strategy "balanced_seeded"
    --mosa_grouping_seed 42
    --dataset meta_math
    --template alpaca
    --cutoff_len 2048
    --overwrite_cache
    --preprocessing_num_workers 16
    --dataloader_num_workers 4
    --output_dir "${output_dir}"
    --logging_steps 10
    --save_steps 1000
    --plot_loss
    --overwrite_output_dir
    --save_only_model false
    --report_to tensorboard
    --per_device_train_batch_size 8
    --gradient_accumulation_steps 2
    --learning_rate 1.0e-5
    --num_train_epochs 3
    --lr_scheduler_type cosine
    --warmup_ratio 0.1
    --bf16
    --ddp_timeout 180000000
    --eval_dataset "${EVAL_DATASETS}"
    --per_device_eval_batch_size 4
    --predict_with_generate
    --eval_strategy steps
    --do_eval
  )

  llamafactory-cli train "${train_args[@]}" "${resume_option[@]}"
}

for r in 32; do
  echo "--- Preparing to train MoSA with equivalent rank: ${r} ---"
  train_mosa_rank "${r}"
done

echo "--- All MoSA parameter-equivalent training runs completed. ---"
