#!/bin/bash
#
# Launch one iteration of iterative SFT on the InterCode-SQL data.
#
# Usage: <script> <iter_id>
#   iter_id  non-negative integer without leading zeros.
#            0   -> fine-tune the base meta-llama/Meta-Llama-3-8B-Instruct.
#            N>0 -> fine-tune the model published for iteration N-1 and also
#                   pass the iter0 data directory as --prior_data_dir.

# Check if an iteration id is provided
if [[ $# -eq 0 ]]; then
    echo "No iteration id provided. Usage: $0 <iter_id>" >&2
    exit 1
fi

iter_id=$1

# Validate before the value is used anywhere. Without this check, a
# non-numeric id makes the `-eq 0` test below error out (status 2) and
# silently select the "later iteration" branch. The raw string would also be
# evaluated as an arithmetic expression in $((...)), where array subscripts
# such as 'a[$(cmd)]' are expanded and can run commands. Leading zeros are
# rejected too: they would be parsed as octal and would not match the
# iter<N> directory names.
if [[ ! $iter_id =~ ^(0|[1-9][0-9]*)$ ]]; then
    echo "Invalid iteration id '$iter_id': expected a non-negative integer. Usage: $0 <iter_id>" >&2
    exit 1
fi

prev_iter_id=$((iter_id - 1))

echo "Iteration id: $1"

# Branch based on iteration id
if (( iter_id == 0 )); then
    model_id="meta-llama/Meta-Llama-3-8B-Instruct"
    prior_data_flag=""  # No prior data flag for iteration 0
else
    # Start from the previous iteration's model. NOTE(review): this expects
    # the iter${prev_iter_id} model to already exist under this hub name;
    # this script itself runs with --push_to_hub False — confirm where it
    # gets published.
    model_id="anonymous/Meta-Llama-3-8B-Instruct-sft-intercode-sql-iter${prev_iter_id}"
    # NOTE(review): always the iter0 data, regardless of iter_id — confirm
    # that intermediate iterations' data should not be included as well.
    prior_data_flag="--prior_data_dir data/intercode_sql/sft/iter0"  # Prior data flag for other iterations
fi

# Run one supervised fine-tuning pass through scripts/train/sft.py with
# LoRA adapters (--use_peft True, r=128, alpha=64, dropout=0.05) in bf16.
# Inputs come from the variables set above: iter_id selects the data and
# output directories, model_id the starting checkpoint, and prior_data_flag
# optionally adds extra training data.
# Per device, 1 sample x 8 accumulation steps gives 8 samples per optimizer
# step. Evaluation and checkpointing both happen every 200 steps.
# Presumably sft.py forwards these options to Hugging Face
# TrainingArguments / TRL SFTConfig — TODO confirm. If so, at most 3
# checkpoints are kept and the best one by eval_loss is reloaded at the end.
#
# ${prior_data_flag} is intentionally left unquoted. It is either empty
# (contributing no argument) or "--prior_data_dir <dir>", which must split
# into two separate words.
# shellcheck disable=SC2086
python scripts/train/sft.py \
    --data_dir "data/intercode_sql/sft/iter${iter_id}" \
    ${prior_data_flag} \
    --model_id "${model_id}" \
    --output_dir "save/2408/intercode_sql/sft/iter${iter_id}" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --max_seq_length 6000 \
    --packing False \
    --torch_dtype bfloat16 \
    --optim adamw_torch_fused \
    --learning_rate 3e-5 \
    --evaluation_strategy steps \
    --eval_steps 200 \
    --save_strategy steps \
    --save_steps 200 \
    --save_total_limit 3 \
    --load_best_model_at_end True \
    --metric_for_best_model eval_loss \
    --use_peft True \
    --lora_alpha 64 \
    --lora_r 128 \
    --lora_dropout 0.05 \
    --lr_scheduler_type cosine \
    --max_grad_norm 0.3 \
    --warmup_steps 10 \
    --bf16 \
    --seed 42 \
    --report_to wandb \
    --logging_first_step \
    --logging_steps 10 \
    --push_to_hub False \