#!/bin/bash

# Launch the iteration-1 SFT ablation run for one correction variant.
# Usage: <script> <correction_id>
# The correction id selects the correction-specific training data
# (data/ablations/alfworld/sft_<correction_id>/iter1) and the output directory.
set -euo pipefail

# Require a non-empty correction id as the first argument; an empty string
# would otherwise silently produce paths like ".../sft_/iter1".
if [[ $# -eq 0 || -z "$1" ]]; then
    echo "No correction id provided. Usage: $0 <correction_id>" >&2
    exit 1
fi

correction_id=$1

printf 'Correction id: %s\n' "$correction_id"

# Training data directories: the shared iter0 SFT data plus this correction's
# iter1 data. Joined into one comma-separated string for --data_dirs.
data_dir_list=(
    "data/alfworld/sft/iter0"
    "data/ablations/alfworld/sft_${correction_id}/iter1"
)
# IFS is changed only inside the subshell, so "${arr[*]}" joins with ","
# without affecting the rest of the script.
data_dirs=$(IFS=,; printf '%s' "${data_dir_list[*]}")

printf 'Data directories: %s\n' "$data_dirs"

# Starting model passed to --model_id; the name suggests the iter0 SFT
# checkpoint (NOTE(review): confirm this is the intended base for iter1).
model_id="anonymous/Meta-Llama-3-8B-Instruct-sft-alfworld-iter0"


# Run SFT with LoRA adapters (--use_peft True, r=128, alpha=64) via
# scripts/train/sft.py. All path/id expansions are quoted so a correction id
# containing spaces or glob characters cannot split the command line.
#
# Assuming sft.py forwards these to HF TrainingArguments (TODO confirm):
#   - each optimizer step sees 4 (per-device batch) x 8 (grad accumulation)
#     = 32 sequences per device; multiply by the process count if distributed.
#   - --load_best_model_at_end requires the save and eval strategies to match
#     and save_steps to be a multiple of eval_steps (both 200 here).
# NOTE(review): newer transformers releases rename --evaluation_strategy to
# --eval_strategy; confirm which name the pinned version accepts.
python scripts/train/sft.py \
    --data_dirs "${data_dirs}" \
    --output_dir "save/2409/alfworld_sft_ablation/correction_${correction_id}/iter1" \
    --model_id "${model_id}" \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --max_seq_length 8000 \
    --packing False \
    --torch_dtype bfloat16 \
    --optim adamw_torch_fused \
    --learning_rate 3e-5 \
    --evaluation_strategy steps \
    --eval_steps 200 \
    --save_strategy steps \
    --save_steps 200 \
    --save_total_limit 3 \
    --load_best_model_at_end True \
    --metric_for_best_model eval_loss \
    --use_peft True \
    --lora_alpha 64 \
    --lora_r 128 \
    --lora_dropout 0.05 \
    --lr_scheduler_type cosine \
    --max_grad_norm 0.3 \
    --warmup_steps 10 \
    --bf16 \
    --seed 42 \
    --report_to wandb \
    --logging_first_step \
    --logging_steps 10 \
    --push_to_hub False \