#!/bin/bash
# Abort on any failing command, on use of an unset variable, and when any
# stage of a pipeline fails (not just the last one).
set -euo pipefail

# Expose only CUDA device 0 to the training process.
export CUDA_VISIBLE_DEVICES=0

# We train the Llama-2-7b-hf model on a 100k-sample subset of the meta-math
# dataset.

# Fixed settings shared by every run launched from this script. They are
# readonly so an accidental reassignment later in the file fails loudly.
readonly PROJECT=BiLoRA  # passed to run_exp.py as ++wandb.project
readonly MODEL=llama2    # passed to run_exp.py as the `model=` override
readonly WANDB=default   # passed to run_exp.py as the `wandb=` override

#######################################
# Launch one LoRA fine-tuning run of run_exp.py with Hydra-style overrides.
# Globals:
#   MODEL   (read) - value of the `model=` override
#   WANDB   (read) - value of the `wandb=` override
#   PROJECT (read) - value of ++wandb.project
# Arguments:
#   $1    - LoRA variant; used for +peft=, ++peft.lora_type and the W&B group
#   $2    - random seed (++seed)
#   $3    - dataset name (++dataset_name and the W&B group)
#   $4    - rho (++peft.rho)
#   $5    - rank of the first adapter (++peft.lora1_rank)
#   $6    - rank of the second adapter (++peft.bi_lora.lora2_rank)
#   $7    - per-device micro batch size (++model.per_device_batch_size)
#   $8... - optional extra overrides, appended after the defaults so they
#           can add or override settings
# Returns:
#   Exit status of run_exp.py, or 2 if fewer than 7 arguments are given.
#######################################
run_llama2_experiment() {
    if (( $# < 7 )); then
        printf 'run_llama2_experiment: expected at least 7 arguments, got %d\n' "$#" >&2
        return 2
    fi

    # Lowercase local name so it does not shadow the global LORA_TYPE.
    local lora_type=$1
    local seed=$2
    local dataset_name=$3
    local rho=$4
    local lora1_rank=$5
    local lora2_rank=$6
    local micro_batch_size=$7
    shift 7

    # One array element per override, so values containing whitespace or
    # glob characters are never split or expanded.
    local -a overrides
    overrides=(
        +init=default
        "+peft=${lora_type}"
        "model=${MODEL}"
        "wandb=${WANDB}"
        "++dataset_name=${dataset_name}"
        "++seed=${seed}"
        ++use_flash_attn=False
        ++init.dtype=fp32
        ++peft.lora_target_modules=all
        "++peft.lora_type=${lora_type}"
        "++peft.lora1_rank=${lora1_rank}"
        "++peft.rho=${rho}"
        "++peft.bi_lora.lora2_rank=${lora2_rank}"
        ++model.learning_rate=5e-4
        # NOTE(review): presumably a patience large enough that early
        # stopping never triggers — confirm in run_exp.py.
        ++model.early_stopping_patience=1e9
        ++model.epochs=2
        ++peft.bi_lora.exceed_rho=True
        ++model.eval_epochs=1
        "++model.per_device_batch_size=${micro_batch_size}"
        "++wandb.project=${PROJECT}"
        "++wandb.group=${dataset_name}_${lora_type}"
    )

    python run_exp.py "${overrides[@]}" "$@"
}

# Settings for this run, passed positionally to run_llama2_experiment in the
# order: lora type, seed, dataset, rho, lora1 rank, lora2 rank, micro batch.
SEED=0
RHO=0.05
LORA_TYPE=bi_lora
DATASET_NAME=meta_math
LORA1_RANK=8
LORA2_RANK=8
MICRO_BATCH_SIZE=4

# Quote every expansion so each value reaches the function as exactly one
# argument, even if it is later edited to contain whitespace or glob chars.
run_llama2_experiment \
    "$LORA_TYPE" \
    "$SEED" \
    "$DATASET_NAME" \
    "$RHO" \
    "$LORA1_RANK" \
    "$LORA2_RANK" \
    "$MICRO_BATCH_SIZE"
