#!/bin/bash

# Read and validate the iteration id (first positional argument).
# Any additional arguments are ignored.
if [[ $# -eq 0 ]]; then
    echo "No iteration id provided. Usage: $0 <iter_id>" >&2
    exit 1
fi

iter_id=$1

# Validate before doing any arithmetic: bash evaluates the operand of $(( ))
# as an expression, so a non-numeric argument would be interpreted rather
# than rejected. The (( )) test only runs once the value is known to be
# digits-only; an id of 0 is rejected because it has no previous iteration.
if [[ ! "$iter_id" =~ ^[0-9]+$ ]] || (( 10#$iter_id < 1 )); then
    echo "Invalid iteration id '${iter_id}': expected a positive integer. Usage: $0 <iter_id>" >&2
    exit 1
fi

# 10# forces base 10 so an id written with leading zeros is not parsed as octal.
prev_iter_id=$((10#$iter_id - 1))

echo "Iteration id: $1"

# Hyperparameter sweep. The training loop pairs the arrays by index
# (learning_rates[i] with betas[i]), so both should list the same number
# of values.
learning_rates=('3e-5')
betas=('0.01')

# Settings shared by every run: epoch count and the root directory under
# which each run's output directory is created.
num_train_epochs='1'
save_prefix='save/240927/'

# Forwarded to the trainer as --desirable_weight / --undesirable_weight.
desirable_weight='1.0'
undesirable_weight='1.0'

# Choose the model id that training is initialised from.
# Default: the KTO checkpoint tagged with the previous iteration id.
# Iteration 1 overrides this with the SFT checkpoint tagged iter0.
# NOTE(review): the "lr5e-7" tag in the KTO model id is hard-coded and differs
# from the learning rate swept in this script (3e-5) - confirm this is the
# intended checkpoint name.
init_model_id="anonymous/Meta-Llama-3-8B-Instruct-kto-alfworld-lr5e-7-bt0.01-ep1-iter${prev_iter_id}"
if [ "$iter_id" -eq 1 ]; then
    init_model_id="anonymous/Meta-Llama-3-8B-Instruct-sft-alfworld-iter0"
fi

# Launch one KTO training run.
# Globals (read): iter_id, init_model_id, num_train_epochs,
#                 desirable_weight, undesirable_weight
# Arguments: $1 - learning rate
#            $2 - beta
#            $3 - output directory for this run
# Returns:   exit status of the python training process
_run_kto_training() {
    local run_lr=$1
    local run_beta=$2
    local run_output_dir=$3

    # Building the command line as an array keeps every value a single
    # argument even if it contains whitespace or glob characters.
    local -a train_args=(
        --data_dir "data/alfworld/dpo/iter${iter_id}"
        --output_dir "${run_output_dir}"
        --model_id "${init_model_id}"
        --per_device_train_batch_size 4
        --per_device_eval_batch_size 4
        --gradient_accumulation_steps 8
        --num_train_epochs "${num_train_epochs}"
        --gradient_checkpointing True
        --max_length 8000
        --max_prompt_length 6000
        --torch_dtype bfloat16
        --optim adamw_torch_fused
        --learning_rate "${run_lr}"
        --eval_strategy steps
        --eval_steps 250
        --save_strategy steps
        --save_steps 250
        --save_total_limit 5
        --load_best_model_at_end True
        --metric_for_best_model eval_loss
        --use_peft True
        --beta "${run_beta}"
        --desirable_weight "${desirable_weight}"
        --undesirable_weight "${undesirable_weight}"
        --lora_alpha 64
        --lora_r 128
        --lora_dropout 0.05
        --lr_scheduler_type cosine
        --max_grad_norm 0.3
        --warmup_steps 10
        --bf16
        --seed 42
        --report_to wandb
        --logging_first_step
        --logging_steps 10
        --push_to_hub False
    )

    python scripts/train/kto.py "${train_args[@]}"
}

# The sweep pairs learning_rates[i] with betas[i]; a length mismatch would
# otherwise pass an empty value to --beta.
if (( ${#learning_rates[@]} != ${#betas[@]} )); then
    echo "learning_rates (${#learning_rates[@]}) and betas (${#betas[@]}) must have the same number of entries" >&2
    exit 1
fi

# Exit status of the most recent failed run; 0 while every run has succeeded.
sweep_status=0

# Loop over the parameter pairs
for i in "${!learning_rates[@]}"; do
    learning_rate=${learning_rates[$i]}
    beta=${betas[$i]}

    save_model_name="alfworld_kto_des${desirable_weight}_undes${undesirable_weight}_lr${learning_rate}_bt${beta}_ep${num_train_epochs}/iter${iter_id}"
    # ${save_prefix%/} drops a trailing slash so the joined path has no "//".
    output_dir="${save_prefix%/}/${save_model_name}"

    echo "Running training with learning_rate=${learning_rate}, beta=${beta}"
    echo "Saving model as: ${output_dir}"
    echo "Initialize model from: ${init_model_id}"

    # Keep sweeping after a failure, but remember it so the script does not
    # report success just because a later run succeeded.
    _run_kto_training "${learning_rate}" "${beta}" "${output_dir}" || {
        sweep_status=$?
        echo "Training failed (exit ${sweep_status}) for learning_rate=${learning_rate}, beta=${beta}" >&2
    }
done

if (( sweep_status != 0 )); then
    exit "${sweep_status}"
fi