# Bi-LoRA: Efficient Sharpness-Aware Minimization for Fine-Tuning Large-Scale Models

Our code framework is built upon [LoRA-GA](https://github.com/Outsider565/LoRA-GA).

## Overview

Fine-tuning large pre-trained models with limited data poses generalization challenges. Sharpness-Aware Minimization (SAM) improves generalization by seeking flat minima, but its memory and compute overhead make it impractical for large models. Combining SAM with parameter-efficient fine-tuning methods such as LoRA is promising, yet applying SAM directly to the LoRA parameters confines sharpness optimization to a restricted subspace, which limits its effectiveness.

To overcome this, we propose **Bi**-directional **Lo**w-**R**ank **A**daptation (Bi-LoRA), which introduces an auxiliary LoRA module to model SAM’s adversarial weight perturbations. This design decouples SAM’s perturbations from LoRA optimization: the main LoRA module adapts to tasks via gradient descent, while the auxiliary module captures sharpness via gradient ascent. This enables Bi-LoRA to achieve flatter minima efficiently while avoiding SAM’s doubled training cost. Extensive experiments demonstrate Bi-LoRA’s effectiveness in enhancing generalization across tasks and architectures.
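
The decoupling described above can be illustrated with a toy sketch. This is not the repository's implementation: the rank-1 factors, the single-sample squared-error loss, the function name `bi_lora_step`, and the rescaling of the auxiliary module to a Frobenius norm of at most `rho` are all assumptions made for illustration. The point it shows is that one residual (one forward/backward pass) feeds both updates, with descent on the main module and ascent on the auxiliary one.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def norm(v):
    return math.sqrt(dot(v, v))

def bi_lora_step(W0, main, aux, x, y, lr=0.05, rho=0.1):
    """One toy step with rank-1 modules main=(b1, a1) and aux=(b2, a2).

    Hypothetical helper for illustration only; not the repository's API.
    """
    (b1, a1), (b2, a2) = main, aux
    rows, cols = len(W0), len(W0[0])
    # Effective weight: frozen W0 + main update + auxiliary perturbation.
    W = [[W0[i][j] + b1[i] * a1[j] + b2[i] * a2[j] for j in range(cols)]
         for i in range(rows)]
    r = [dot(W[i], x) - y[i] for i in range(rows)]  # residual W x - y
    loss = 0.5 * dot(r, r)

    # The gradient w.r.t. the effective weight is G = r x^T, so for any
    # factor pair (b, a): dL/db = G a = r (x.a) and dL/da = G^T b = x (r.b).
    def factor_grads(b, a):
        return [ri * dot(x, a) for ri in r], [xj * dot(r, b) for xj in x]

    gb1, ga1 = factor_grads(b1, a1)
    gb2, ga2 = factor_grads(b2, a2)

    # Main module: gradient descent (adapts to the task).
    b1 = [p - lr * g for p, g in zip(b1, gb1)]
    a1 = [p - lr * g for p, g in zip(a1, ga1)]
    # Auxiliary module: gradient ascent (tracks the adversarial perturbation).
    b2 = [p + lr * g for p, g in zip(b2, gb2)]
    a2 = [p + lr * g for p, g in zip(a2, ga2)]

    # Assumed constraint: keep ||b2 a2^T||_F = ||b2|| * ||a2|| within rho.
    size = norm(b2) * norm(a2)
    if size > rho:
        s = math.sqrt(rho / size)
        b2 = [s * v for v in b2]
        a2 = [s * v for v in a2]
    return loss, (b1, a1), (b2, a2)

# Toy usage: identity base weight, LoRA-style init with zero b factors.
W0 = [[1.0, 0.0], [0.0, 1.0]]
main = ([0.0, 0.0], [0.5, 0.5])
aux = ([0.0, 0.0], [0.5, -0.5])
x, y = [1.0, 2.0], [2.0, 1.0]
for _ in range(20):
    loss, main, aux = bi_lora_step(W0, main, aux, x, y)
```

Because both modules share the same residual, the sketch needs only one gradient computation per step, whereas standard SAM requires a second pass at the perturbed weights.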

![bi_lora_framework2](./resource/figures/bi_lora_framework.png)

## Quick Start

### 1. Install Dependencies

```bash
conda create -n bi_lora python=3.10.13
conda activate bi_lora
cd bi_lora
pip install -r requirements.txt
```

> Note: `flash-attn` must be compatible with both your CUDA and PyTorch versions. Our experiments use CUDA 11.8 and PyTorch 2.3.0, for which we install `flash-attn==2.5.8`.
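
A quick way to confirm that an environment matches the versions in the note is sketched below. It uses only the standard library; the helper names `installed_version` and `matches` are hypothetical, and the distribution names `torch` and `flash-attn` are assumed to be the ones pip registers.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions taken from the note above.
EXPECTED = {"torch": "2.3.0", "flash-attn": "2.5.8"}

def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

def matches(installed, wanted):
    """Compare versions, ignoring a local build suffix such as '+cu118'."""
    return installed is not None and installed.split("+")[0] == wanted

for name, wanted in EXPECTED.items():
    found = installed_version(name)
    status = "ok" if matches(found, wanted) else "mismatch"
    print(f"{name}: installed={found}, expected={wanted} -> {status}")
```

A mismatch does not necessarily mean the setup is broken; other CUDA/PyTorch combinations may work with a different `flash-attn` release.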

### 2. Train Llama 2-7B on the Math Task with Bi-LoRA

Run the following command to fine-tune Llama 2-7B on MetaMathQA using Bi-LoRA:

```bash
cd bi_lora
bash scripts/run_math.sh
```

The `run_math.sh` script executes the following training command:

```bash
python ../run_exp.py \
    +init=default \
    +peft=bi_lora \
    model=llama2 \
    wandb=run \
    ++dataset_name=meta_math \
    ++seed=42 \
    ++use_flash_attn=True \
    ++init.dtype=fp32 \
    ++peft.lora_target_modules=all \
    ++peft.lora_type=bi_lora \
    ++peft.lora1_rank=8 \
    ++peft.rho=0.1 \
    ++peft.bi_lora.lora2_rank=8 \
    ++model.learning_rate=5e-4 \
    ++model.early_stopping_patience=1e9 \
    ++model.epochs=2 \
    ++peft.bi_lora.exceed_rho=True \
    ++model.eval_epochs=1 \
    ++model.per_device_batch_size=4
```
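
The `+` and `++` prefixes in the command look like Hydra's override syntax, in which a plain `key=value` overrides an existing entry, `+key=value` adds a new one, and `++key=value` adds the key or overrides it if it already exists. Assuming that is the mechanism in use here, the sketch below classifies each argument accordingly; `classify_override` is a hypothetical helper for reading the command, not part of Hydra or of this repository.

```python
def classify_override(arg):
    """Split a Hydra-style override into (mode, key, value)."""
    if arg.startswith("++"):
        mode, body = "add_or_override", arg[2:]
    elif arg.startswith("+"):
        mode, body = "add", arg[1:]
    else:
        mode, body = "override", arg
    key, _, value = body.partition("=")
    return mode, key, value

# A few arguments copied from the training command above.
for arg in ["+peft=bi_lora", "model=llama2", "++peft.rho=0.1",
            "++peft.bi_lora.lora2_rank=8"]:
    print(classify_override(arg))
```

Read this way, `++` is the safe choice when tweaking hyperparameters such as `peft.rho` or `model.learning_rate`, since it works whether or not the key already exists in the base config.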

It sets up the training process, defining key parameters such as:

- **Model**: Llama 2-7B
- **Dataset**: MetaMathQA
- **Micro-Batch Size**: 4 (`model.per_device_batch_size=4`)
- **Training Epochs**: 2
- **LoRA Variant**: Bi-LoRA

For more details, check `scripts/run_math.sh` and modify parameters as needed.