# Extreme combined compression of large language models through joint optimization



joint_optimization is a compression technique for large language models (LLMs) that jointly optimizes weight quantization and weight sparsity (pruning).


## Contents
- [Dependencies](#dependencies)
- [Install](#install)
- [Usage](#usage)

## Dependencies
- torch: tested on v2.0.1+cu118
- transformers: tested on v4.31.0
- accelerate: tested on v0.21.0
- datasets: tested on v2.14.4
- timm: tested on v0.9.5

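
To check whether an environment matches the tested versions listed above, a small script along these lines may help. It is only a sketch: `TESTED`, `base_version`, and `check_versions` are hypothetical names that are not part of this repository, and a mismatch does not necessarily mean the code will fail.

```python
from importlib import metadata

# Versions this README reports as tested (local build tags such as "+cu118" are ignored).
TESTED = {
    "torch": "2.0.1",
    "transformers": "4.31.0",
    "accelerate": "0.21.0",
    "datasets": "2.14.4",
    "timm": "0.9.5",
}


def base_version(version):
    """Drop a local build suffix, e.g. '2.0.1+cu118' -> '2.0.1'."""
    return version.split("+", 1)[0]


def check_versions(tested=TESTED, lookup=metadata.version):
    """Return {package: (tested, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in tested.items():
        try:
            installed = base_version(lookup(name))
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions().items():
        print(f"{name}: tested {wanted}, found {installed or 'not installed'}")
```

Running the script prints one line per package whose installed version differs from the tested one, and prints nothing when everything matches.
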
## Install
```
conda create -n joint_optimization python=3.10 -y
conda activate joint_optimization
cd joint_optimization
# install dependencies
pip install --upgrade pip 
pip install -e .
```

Real quantization relies on the kernels from [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), so you should also install the bug-fixed AutoGPTQ as follows:
```
pip install auto_gptq
```

**Customized CUDA Operator**
```
cd models/ops
python setup.py install
```

## Usage
**Example**:
1. Obtain the channel-wise scales and shifts required for initialization. Optionally, you can generate them yourself with the provided script:
```
python generate_act_scale_shift.py --model /PATH/TO/your_model
```
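
For intuition, channel-wise statistics of this kind can be sketched in a few lines of plain Python. The definitions below are an assumption (scale as the per-channel maximum absolute activation, shift as the midpoint of the per-channel range); `generate_act_scale_shift.py` may compute them differently, and `channel_scale_shift` is a hypothetical helper, not a function from this repository.

```python
def channel_scale_shift(activations):
    """Compute per-channel scale and shift from calibration activations.

    activations: list of rows (one per token), each a list of per-channel values.
    Assumed definitions: scale = max |x|, shift = (max x + min x) / 2.
    """
    columns = list(zip(*activations))  # one tuple of values per channel
    scales = [max(abs(v) for v in col) for col in columns]
    shifts = [(max(col) + min(col)) / 2 for col in columns]
    return scales, shifts


acts = [
    [0.5, -2.0, 10.0],
    [-1.5, 4.0, 12.0],
    [1.0, 1.0, 8.0],
]
scales, shifts = channel_scale_shift(acts)
print(scales)  # [1.5, 4.0, 12.0]
print(shifts)  # [-0.25, 1.0, 10.0]
```

The third channel shows why a shift can matter: its values sit far from zero, so centering them first leaves a much smaller range to cover.
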

2. Run joint optimization:
```
# llama-7b W2A16g128 wanda unstructured 75%
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/LLaMA/llama-7b  \
--epochs 20 --output_dir ./log/llama-7b-w2a16g128-wanda-0.75 \
--eval_ppl --wbits 2 --abits 16 --group_size 128 --lwc --let \
--sparsity_ratio 0.75 --sparsity_type unstructured --sparsity_method wanda 

# llama3-8b W3A16g128 dsnot unstructured 50%
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/LLaMA/llama3-8b  \
--epochs 20 --output_dir ./log/llama3-8b-w3a16g128-dsnot-0.5 \
--eval_ppl --wbits 3 --abits 16 --group_size 128 --lwc  \
--sparsity_ratio 0.5 --sparsity_type unstructured --sparsity_method dsnot \
--initial_method wanda \
--skip_layer no_skip \
--skip_sub_layer no_skip \
--max_cycle_time 50 \
--update_threshold 0.1 \
--pruning_granularity row \
--pow_of_var_regrowing 1 \
--without_same_sign

# llama2-13b W2A16g128 besa unstructured 60%
CUDA_VISIBLE_DEVICES=0 python main.py \
--model /PATH/TO/LLaMA/llama2-13b  \
--epochs 20 --output_dir ./log/llama2-13b-w2a16g128-besa-0.6 \
--eval_ppl --wbits 2 --abits 16 --group_size 128 --lwc --let \
--sparsity_ratio 0.6 --sparsity_type unstructured --sparsity_method besa \
--let_lr 5e-4 --alpha 0.6 \
--blocksize 1 --sparsity-beta 5e0
```
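
The first example selects `--sparsity_method wanda` with `--sparsity_type unstructured`. As background, Wanda is commonly described as scoring each weight by its magnitude multiplied by the L2 norm of the matching input activation, then removing the lowest-scoring weights within each output row. The sketch below illustrates that idea in plain Python on toy data. It is not this repository's implementation, and `wanda_prune` is a hypothetical helper.

```python
import math


def wanda_prune(weight, activations, sparsity_ratio):
    """Zero the lowest-scoring weights in each output row.

    weight: list of rows, shape [out_features][in_features]
    activations: calibration inputs, shape [tokens][in_features]
    sparsity_ratio: fraction of weights to remove in every row
    """
    # L2 norm of each input feature across the calibration tokens.
    norms = [math.sqrt(sum(x * x for x in col)) for col in zip(*activations)]
    pruned = []
    for row in weight:
        scores = [abs(w) * n for w, n in zip(row, norms)]
        n_prune = int(len(row) * sparsity_ratio)
        # Indices of the n_prune smallest scores in this row.
        drop = set(sorted(range(len(row)), key=scores.__getitem__)[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


W = [[0.9, -0.1, 0.4, 0.2],
     [-0.3, 0.8, 0.05, -0.6]]
X = [[1.0, 10.0, 0.5, 2.0],
     [1.0, 10.0, 0.5, 2.0]]
print(wanda_prune(W, X, 0.5))
# [[0.9, -0.1, 0.0, 0.0], [0.0, 0.8, 0.0, -0.6]]
```

In the first row the small weight `-0.1` survives because its input feature has large activations, whereas pruning by magnitude alone would have removed it.
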

3. weight-activation quantization

Arguments in more detail:
- `--model`: local model path or Hugging Face model name.
- `--wbits`: number of bits for weight quantization.
- `--abits`: number of bits for activation quantization.
- `--group_size`: group size for weight quantization. If not set, per-channel weight quantization is used by default.
- `--epochs`: number of training epochs.
- `--nsamples`: number of calibration samples; defaults to 128.
- `--eval_ppl`: evaluate the perplexity of the quantized model.
- `--resume`: load pre-trained parameters.
- `--multigpu`: run inference for larger networks across multiple GPUs.
- `--real_quant`: perform real quantization, which reduces memory usage.
- `--save_dir`: directory in which to save the quantized model for further exploration.
- `--sparsity_ratio`: target fraction of weights set to zero (e.g. `0.5`).
- `--sparsity_type`: sparsity pattern; defaults to `unstructured`.
- `--sparsity_method`: pruning method used to induce sparsity (`wanda`, `dsnot`, and `besa` are used in the examples above).
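
To make `--wbits` and `--group_size` concrete, here is a minimal sketch of asymmetric uniform weight quantization on a single row, comparing one group per row (per-channel) with smaller groups. It is an illustration under simplifying assumptions, not the quantizer used by this repository, which may differ (for example in how the clipping range is chosen). `fake_quantize_row` is a hypothetical helper.

```python
def fake_quantize_row(row, wbits, group_size=None):
    """Quantize and immediately dequantize one weight row.

    Each group of `group_size` consecutive weights gets its own range.
    group_size=None uses a single group for the whole row (per-channel).
    """
    group_size = group_size or len(row)
    levels = 2 ** wbits - 1  # highest integer code
    out = []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels or 1.0  # fall back to 1.0 for constant groups
        for w in group:
            q = round((w - lo) / scale)  # integer code in [0, levels]
            out.append(lo + q * scale)   # dequantized value
    return out


row = [0.0, 0.25, 0.5, 0.75, 0.0, 8.0, 16.0, 24.0]
print(fake_quantize_row(row, wbits=2))
# [0.0, 0.0, 0.0, 0.0, 0.0, 8.0, 16.0, 24.0]
print(fake_quantize_row(row, wbits=2, group_size=4))
# [0.0, 0.25, 0.5, 0.75, 0.0, 8.0, 16.0, 24.0]
```

With a single group, the large weights stretch the range and the small ones collapse to zero. Splitting the row into groups of four gives each group its own range, which is the motivation for settings such as `--group_size 128` at very low bit widths.
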



