#!/usr/bin/env bash
#
# Run LoRA fine-tuning of BASE_MODEL on DATA_PATH via finetune_infobatch.py
# (launched through `accelerate`), once for each sampling strategy listed in
# sampling_tags. Every run gets its own output subdirectory,
# "$OUTPUT_DIR/<sampling_tag>", so runs cannot overwrite each other's results.
#
# Usage: bash <this script>
#   Run from the directory containing finetune_infobatch.py and DATA_PATH
#   (both are referenced by relative path).
#
# Exit status: 0 if every run succeeded, 1 if any run failed. A failed run
# does not stop the remaining runs.
set -euo pipefail

readonly BASE_MODEL="baffo32/decapoda-research-llama-7B-hf"
readonly DATA_PATH="alpaca_data_dq_k5_1k.json"
readonly OUTPUT_DIR="./lora-alpaca-tmp"
readonly BATCH_SIZE=64
readonly MICRO_BATCH_SIZE=8
readonly NUM_EPOCHS=15
readonly LEARNING_RATE=1.5e-4
readonly CUTOFF_LEN=512
readonly VAL_SET_SIZE=1000
readonly LORA_R=16
# Passed through as a single literal argument. It must stay quoted wherever it
# is expanded, because `[...]` is a glob bracket expression to the shell.
readonly LORA_TARGET_MODULES='[q_proj,k_proj,v_proj,o_proj]'

# Sampling strategies to train with, in order.
sampling_tags=("stratified" "pruning" "normal")

echo "Batch size: $BATCH_SIZE, Micro batch size: $MICRO_BATCH_SIZE, Epochs: $NUM_EPOCHS"

# Arguments shared by every run. An array keeps each value a single word, with
# no word splitting and no pathname expansion. --output_dir is added per run.
common_args=(
    --base_model "$BASE_MODEL"
    --data_path "$DATA_PATH"
    --num_epochs "$NUM_EPOCHS"
    --val_set_size "$VAL_SET_SIZE"
    --learning_rate "$LEARNING_RATE"
    --batch_size "$BATCH_SIZE"
    --cutoff_len "$CUTOFF_LEN"
    --lora_target_modules "$LORA_TARGET_MODULES"
    --lora_r "$LORA_R"
    --micro_batch_size "$MICRO_BATCH_SIZE"
)

mkdir -p -- "$OUTPUT_DIR"

# Tags whose training run exited non-zero.
failed_tags=()

for sampling_tag in "${sampling_tags[@]}"; do
    # One directory per strategy. A single shared --output_dir would let each
    # run write over the artifacts of the previous one.
    run_dir="$OUTPUT_DIR/$sampling_tag"
    mkdir -p -- "$run_dir"

    echo "Training with $sampling_tag sampling..."

    # Guarded by `if`, so a failing run is recorded instead of aborting the
    # sweep under `set -e`. It is also not reported as completed.
    if accelerate launch --main_process_port 0 finetune_infobatch.py \
        --sampling_tag "$sampling_tag" \
        --output_dir "$run_dir" \
        "${common_args[@]}"; then
        echo "Completed training with sampling_tag: $sampling_tag"
    else
        status=$?
        printf 'Training FAILED with sampling_tag: %s (exit status %d)\n' \
            "$sampling_tag" "$status" >&2
        failed_tags+=("$sampling_tag")
    fi
    echo "----------------------------------------"
done

if (( ${#failed_tags[@]} > 0 )); then
    printf 'Training failed for sampling_tag(s): %s\n' "${failed_tags[*]}" >&2
    exit 1
fi

echo "All training runs completed!"
echo "Results saved in: $OUTPUT_DIR (one subdirectory per sampling_tag)"



