# StatQAT

StatQAT is a lightweight Quantization-Aware Training (QAT) framework focused on statistical and analytic update rules. It provides:

- A composable set of quantization data formats (uniform ints, FP4/FP8 variants, NF4/SF4)
- Pluggable algorithms (`minmax`, `iterative`, `analytic`, `octav`, `lsq`) with STE- or custom-gradient backprop
- Drop-in quantized layers (`QLinear`, `QConv2d`) with per-tensor, per-channel, or block-wise weight quantization
- Simple configuration objects (`QuantConfig`, `QuantModuleConfig`) and utilities to patch existing models

The project ships with a small, fast pytest suite and runnable experiments for CNNs and LLMs.

## Installation

Requirements: Python 3.10+, PyTorch 2.6+ (CPU is sufficient for tests).

Option A, editable install:

```bash
pip install -e .
```

Option B, from requirements:

```bash
pip install -r requirements.txt
```

Format and lint with ruff:

```bash
make format
```

## Quickstart

Create a quantized linear layer with INT4 weights and FP4 activations using the MinMax algorithm:

```python
import torch
from quant_mp.config import QuantConfig, QuantModuleConfig
from quant_mp.datatypes.template import get_data_format
from quant_mp.algs.template import get_algorithm
from quant_mp.QModules import QLinear

act = QuantConfig(
    qval_data_format=get_data_format("fp4_e2m1"),
    qparam_data_format=get_data_format("fp32"),
    algorithm=get_algorithm("minmax"),
)
wt = QuantConfig(
    qval_data_format=get_data_format("int4"),
    qparam_data_format=get_data_format("fp32"),
    algorithm=get_algorithm("minmax"),
    qblock_size="channel",  # per-output-channel
)
qcfg = QuantModuleConfig(activation=act, weight=wt)

layer = QLinear(768, 768, qlinear_config=qcfg)
x = torch.randn(2, 768)
y = layer(x)
```
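
The `qblock_size="channel"` setting above requests one set of quantization parameters per output channel. As a rough illustration of how per-tensor, per-channel, and block-wise granularity differ, the sketch below computes symmetric absmax scales for a small weight matrix in plain Python. The helper names (`absmax_scale`, `scales_per_tensor`, `scales_per_channel`, `scales_per_block`), the absmax rule, and the block layout are assumptions made for illustration; they are not taken from StatQAT's code.

```python
# Illustrative only: symmetric absmax scales at three granularities.
# Helper names and the block layout are hypothetical, not StatQAT's API.

QMAX = 7  # largest magnitude on a symmetric signed 4-bit grid (-7..7)


def absmax_scale(values, qmax=QMAX):
    """Scale that maps the largest |value| onto qmax (1.0 for all-zero input)."""
    peak = max(abs(v) for v in values)
    return peak / qmax if peak > 0 else 1.0


def scales_per_tensor(weight):
    """One scale shared by every element of the matrix."""
    return absmax_scale([v for row in weight for v in row])


def scales_per_channel(weight):
    """One scale per row, i.e. per output channel of a Linear weight."""
    return [absmax_scale(row) for row in weight]


def scales_per_block(weight, block_size):
    """One scale per contiguous block of `block_size` values within each row."""
    return [
        [absmax_scale(row[i:i + block_size]) for i in range(0, len(row), block_size)]
        for row in weight
    ]


weight = [
    [0.7, -0.14, 0.07, 0.0],  # small-magnitude channel
    [7.0, 3.5, -1.4, 0.7],    # large-magnitude channel
]

print(scales_per_tensor(weight))               # single scale, set by the 7.0 outlier
print(scales_per_channel(weight))              # the first row gets its own, smaller scale
print(scales_per_block(weight, block_size=2))  # finer still: two scales per row
```

With a single per-tensor scale, the small-magnitude first row shares a step size chosen for the second row's largest value, so most of its entries round to zero. Finer granularity trades extra stored scales for lower rounding error.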

Patch an existing model (e.g., a Hugging Face transformer) in place, reusing `qcfg` from the example above:

```python
from transformers import AutoModelForCausalLM
from quant_mp.utils import patch_model

model = AutoModelForCausalLM.from_pretrained("facebook/MobileLLM-125M")
patch_model(model, qcfg)  # replaces Linear/Conv with QLinear/QConv2d
```
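
In-place patching of this kind is commonly done by walking the module tree and swapping children with `setattr`. The sketch below shows that general PyTorch pattern with a stand-in layer. `FakeQuantLinear` and `swap_linear_layers` are hypothetical names, and nothing here is taken from the actual implementation of `patch_model`.

```python
import torch
from torch import nn


class FakeQuantLinear(nn.Linear):
    """Stand-in for a quantized layer; here it behaves exactly like nn.Linear."""


def swap_linear_layers(module: nn.Module) -> int:
    """Recursively replace nn.Linear children in place; return how many were swapped."""
    swapped = 0
    for name, child in module.named_children():
        if type(child) is nn.Linear:
            new = FakeQuantLinear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            new.load_state_dict(child.state_dict())  # keep the existing weights
            setattr(module, name, new)  # rebinding the child mutates the parent in place
            swapped += 1
        else:
            swapped += swap_linear_layers(child)
    return swapped


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Sequential(nn.Linear(16, 4)))
x = torch.randn(2, 8)
before = model(x)
count = swap_linear_layers(model)
after = model(x)

print(count)                          # 2
print(torch.allclose(before, after))  # True: weights were copied over
```

Because the swap happens on the parent module, any existing reference to `model` sees the new layers without reassignment.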

## Data Formats

- Uniform integer: `int2`, `int3`, `int4`, `int8`
- Floating point: `fp4_e2m1`, `fp4_e3m0`, `fp8_e4m3`, `fp8_e4m3fnuz`, `fp8_e5m2`, `fp32`
- Non-uniform: `nf4`, `sf4-v5`

Access via `quant_mp.datatypes.template.get_data_format(name)`.
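
To make the format names more concrete, the sketch below enumerates the non-negative values of a 4-bit float with 2 exponent bits and 1 mantissa bit, next to a signed 4-bit integer grid. It assumes the common E2M1 convention (exponent bias 1, subnormals, no Inf/NaN encodings) and a two's-complement integer range. Whether `fp4_e2m1` and `int4` in this project follow exactly those conventions is an assumption, and `minifloat_values` is a hypothetical helper, not part of `quant_mp`.

```python
def minifloat_values(exp_bits, man_bits, bias):
    """Non-negative values of a small float format with subnormals and no Inf/NaN."""
    values = set()
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            frac = m / 2 ** man_bits
            if e == 0:
                # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
                values.add(frac * 2.0 ** (1 - bias))
            else:
                # Normal: implicit leading 1, exponent e - bias.
                values.add((1.0 + frac) * 2.0 ** (e - bias))
    return sorted(values)


fp4_e2m1 = minifloat_values(exp_bits=2, man_bits=1, bias=1)
int4 = list(range(-8, 8))  # assumed two's-complement signed 4-bit grid

print(fp4_e2m1)           # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print(int4[0], int4[-1])  # -8 7
```

The float grid is non-uniform, with levels packed more densely near zero, while the integer grid is evenly spaced. The two therefore suit different value distributions.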

## Algorithms

- `minmax` (fit + STE)
- `iterative` (iterative fit + STE)
- `analytic` (closed-form helpers for SNR/levels)
- `octav` (iterative with outside/inside masks)
- `lsq` (learned step size; basic variant included)

Access via `quant_mp.algs.template.get_algorithm(name, algorithm_init_kwargs=...)`.
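
As a rough illustration of the "fit + STE" pattern listed for `minmax`, the sketch below fits a symmetric scale from the tensor's absolute maximum, rounds onto a signed integer grid, and uses the detach trick so that gradients pass through the rounding as if it were the identity. This is a generic straight-through estimator in plain PyTorch, not StatQAT's implementation. The name `fake_quant_minmax_ste` and the symmetric-grid choice are assumptions.

```python
import torch


def fake_quant_minmax_ste(x: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Quantize-dequantize x on a symmetric grid with a straight-through gradient."""
    qmax = 2 ** (bits - 1) - 1  # 7 for 4 bits
    # "Fit": choose the scale so that max|x| lands on the largest grid level.
    scale = x.detach().abs().max().clamp_min(1e-8) / qmax
    x_q = torch.clamp(torch.round(x / scale), -qmax, qmax) * scale
    # STE: the forward value is x_q, the backward gradient is that of the identity.
    return x + (x_q - x).detach()


x = torch.tensor([-1.0, -0.25, 0.05, 0.6, 1.4], requires_grad=True)
y = fake_quant_minmax_ste(x)
y.sum().backward()

print(y.detach())  # values snapped to multiples of 1.4 / 7 = 0.2
print(x.grad)      # all ones: rounding is treated as identity in backward
```

Without the straight-through trick, `torch.round` has zero gradient almost everywhere, so the weights behind a quantizer would never receive a training signal.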

## Experiments

- ResNet/CIFAR: `python exps/run_exp_resnet.py`
- LLM QAT/fine-tuning: `python exps/run_exp_llm.py --help` (uses Hugging Face; downloads models and datasets unless they are already cached)
- LLM eval: `python exps/eval_llm.py`
- Figures: `exps/gen_fig_*.py` generate grids/SNR visuals

The file `exps/qat_config.py` contains sample configurations used by `exps/run_exp.py`.

## Testing

Run the CPU-friendly test suite:

```bash
export CUDA_VISIBLE_DEVICES=""  # force CPU if a GPU is present
pytest -q
```
