# Low-Rank Compression of Language Models via Differentiable Rank Selection
Model compression using low-rank decomposition and learnable rank selection

## Setup
### Installation:
* Install PyTorch (`torch` > 2.1) with CUDA support
* `pip install -r requirements.txt`
* `pip install lm-eval==0.4.2`

* Log in to your Weights & Biases (wandb) account in the environment before running, so that the final metrics can be logged
  

## Reproduce
#### Reproduce llama-2-7b
```
# Compress Llama-2-7b at each parameter ratio listed in COMP_VALUES by calling
# train.py once per ratio. Run from the directory that contains train.py.

# ---- Training hyperparameters ----
MAX_LEN=256
NUM_TRAIN_SAMPLES=3000
DISTILL_MODE="hs_last" # if empty, the run is named "<model>_pretrain_<ratio>" instead of "..._distill_..."
LR=1e-2
EPOCHS=30
SCALE_COMP=1.
eval_freq_steps=0 # for pre-training mode only
TV_SCALE=1. # scale for tv loss (passed to train.py as --tv_loss)
EVAL_BS=8
BATCH_SIZE=4
SCHED_DIST=3

# ---- Model and sweep ----
MODEL=meta-llama/Llama-2-7b-hf
CACHE_DIR=cache_train_llama2 # add your path where models are downloaded
COMP_VALUES=(0.90 0.85 0.80) # param ratios to run. If you want to just compress by 20%, set COMP_VALUES=(0.80)

for i in "${!COMP_VALUES[@]}"; do
    COMP="${COMP_VALUES[$i]}"

    # "${MODEL#*/}" drops the "<org>/" prefix so the experiment name has no slash.
    if [[ -z "$DISTILL_MODE" ]]; then
        EXP_NAME="${MODEL#*/}_pretrain_${COMP}"
    else
        EXP_NAME="${MODEL#*/}_distill_${COMP}_v2_${SCHED_DIST}"
    fi

    # Arguments shared by every run. Building them in an array keeps a single
    # copy of the command line and keeps each value a single word even if it
    # contains spaces (e.g. CACHE_DIR).
    TRAIN_ARGS=(
        --eval_full
        --tv_loss="$TV_SCALE"
        --bias_init
        --alpha=0.5
        --lr_schedule="plateau"
        --schedule_distillation="$SCHED_DIST"
        --scale_compression="$SCALE_COMP"
        --target_param_ratio="$COMP"
        --mask_eval_type="threshold"
        --act_aware=activation
        --model_name="$MODEL"
        --epochs="$EPOCHS"
        --eval_freq=3
        --distill_mode="$DISTILL_MODE"
        --batch_size="$BATCH_SIZE"
        --lr="$LR"
        --num_train_samples="$NUM_TRAIN_SAMPLES"
        --exp_name="$EXP_NAME"
        --max_length="$MAX_LEN"
        --cache_dir="$CACHE_DIR"
        --save_model=reconstruct
        --eval_batch_size="$EVAL_BS"
        --eval_freq_steps="$eval_freq_steps"
    )

    # The first iteration runs without extra arguments; every later iteration
    # additionally loads the distillation and activation caches
    # (presumably the ones written during the first run -- confirm in train.py).
    if (( i > 0 )); then
        TRAIN_ARGS+=(--load_distill_cache --load_act_cache)
    fi

    python train.py "${TRAIN_ARGS[@]}"
done
```

