# LATMiX: Learnable Affine Transformations for Microscaling Quantization of LLMs

This repository contains the official implementation of **LATMiX** (Learnable Affine Transformations for Microscaling Quantization of LLMs), a novel quantization algorithm for Large Language Models (LLMs) using microscaling (MX) quantization formats with learned affine transformations.

## Overview

LATMiX enables efficient low-bit quantization of LLMs by learning affine transformations that redistribute activation outliers and reduce quantization error.
The algorithm supports various microscaling formats (e.g., MXFP4, FP4_E2M1) and post-training quantization methods (RTN, GPTQ) to achieve state-of-the-art accuracy-efficiency trade-offs.
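
Transforms of this kind are typically chosen so that they leave the full-precision model's output unchanged: an invertible affine map applied to a layer's input can be folded into that layer's weights and bias. Below is a minimal NumPy sketch of this identity. It assumes a single linear layer `y = W x + c` and a transform `T(x) = A x + v`. `fold_affine` is a hypothetical helper, not the repository's implementation.

```python
import numpy as np

# Hypothetical setup: one linear layer y = W x + c and an invertible
# affine map T(x) = A x + v applied to its input activations.
rng = np.random.default_rng(0)
d_in, d_out = 8, 4
W = rng.standard_normal((d_out, d_in))
c = rng.standard_normal(d_out)
A = rng.standard_normal((d_in, d_in)) + d_in * np.eye(d_in)  # well-conditioned, invertible
v = rng.standard_normal(d_in)

def fold_affine(W, c, A, v):
    """Return (W', c') such that W' (A x + v) + c' == W x + c for every x."""
    W_new = W @ np.linalg.inv(A)   # the weights absorb the inverse transform
    c_new = c - W_new @ v          # the bias absorbs the shift
    return W_new, c_new

W_new, c_new = fold_affine(W, c, A, v)
x = rng.standard_normal(d_in)
print(np.allclose(W @ x + c, W_new @ (A @ x + v) + c_new))  # True
```

Only the quantized model changes. In this setting, the quantizer sees the transformed activations `A x + v` and the folded weights `W A^-1` instead of `x` and `W`, so the choice of `A` and `v` is free to target quantization error.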

<p align="center">
  <img src="LATMiX.svg" alt="LATMiX Algorithm Overview" width="800"/>
</p>

<p align="center">
  <em>Figure 1: LATMiX takes into account both the MX block structure and the distribution of features to diffuse outliers.</em>
</p>
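
A simplified MXFP4 quantize-dequantize routine illustrates the caption's point about MX block structure. This sketch assumes the OCP Microscaling v1.0 rules: 32-element blocks, a shared power-of-two scale, and FP4 E2M1 elements. It ignores the E8M0 exponent range and breaks rounding ties toward the smaller magnitude. It is not the repository's `quantizer.py`.

```python
import numpy as np

# FP4 E2M1 representable magnitudes (the sign is handled separately).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
E2M1_EMAX = 2     # exponent of the largest E2M1 value (6 = 1.5 * 2**2)
BLOCK_SIZE = 32   # elements sharing one scale

def mxfp4_fake_quant(x: np.ndarray) -> np.ndarray:
    """Quantize-dequantize a 1-D float array in blocks of 32 elements."""
    assert x.ndim == 1 and x.size % BLOCK_SIZE == 0
    blocks = x.reshape(-1, BLOCK_SIZE)
    out = np.zeros(blocks.shape, dtype=np.float64)
    for b, block in enumerate(blocks):
        amax = np.max(np.abs(block))
        if amax == 0:
            continue  # an all-zero block stays zero
        # Shared power-of-two scale: 2**(floor(log2(amax)) - emax).
        scale = 2.0 ** (np.floor(np.log2(amax)) - E2M1_EMAX)
        scaled = np.abs(block) / scale
        # Nearest grid point per element; magnitudes above 6 map to 6.
        idx = np.argmin(np.abs(scaled[:, None] - E2M1_GRID[None, :]), axis=1)
        out[b] = np.sign(block) * E2M1_GRID[idx] * scale
    return out.reshape(x.shape)

# One outlier in the first block forces a large shared scale there, which
# flushes that block's small values to zero; the second block is unaffected.
x = np.full(64, 0.1)
x[0] = 96.0
q = mxfp4_fake_quant(x)
print(q[0], q[1], q[33])  # 96.0 0.0 0.09375
```

Because the scale is shared per block, an outlier costs precision only for the other values in its own block. This is one reason a transform that accounts for block boundaries can behave differently from one that only spreads outliers globally.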

[//]: # (## Key Features)

[//]: # ()
[//]: # (- **Microscaling Quantization**: Support for various MX formats &#40;MXFP4, FP4_E2M1, etc.&#41;)

[//]: # (- **Learned Affine Transforms**: Optimized transformations for both weights and activations)

[//]: # (- **Multiple PTQ Algorithms**: Round-to-Nearest &#40;RTN&#41; and GPTQ support)

[//]: # (- **Flexible Configuration**: Extensive command-line interface for experimentation)

[//]: # (- **MLflow Integration**: Built-in experiment tracking and logging)

## Installation

### Prerequisites

- Python 3.12 or higher
- CUDA-compatible GPU (recommended)
- 16GB+ GPU memory for most models
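
For a sense of scale, here is the weight-storage arithmetic for BF16 versus MXFP4. It assumes the OCP MX layout: 4-bit elements plus one 8-bit shared scale per 32-element block. The 1B parameter count is illustrative. This estimate covers weights only; activations, calibration data, and transform-optimization state are not counted.

```python
# Weight storage only; the parameter count below is illustrative.
BITS_BF16 = 16
# MXFP4: 4-bit elements plus one 8-bit shared scale per 32-element block.
BITS_MXFP4 = 4 + 8 / 32  # = 4.25 bits per element

def weight_gb(n_params: float, bits_per_param: float) -> float:
    """Weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

n_params = 1e9  # hypothetical 1B-parameter model
print(weight_gb(n_params, BITS_BF16))   # 2.0
print(weight_gb(n_params, BITS_MXFP4))  # 0.53125
```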

### Environment Setup

1. Create a virtual environment (recommended):
```bash
python3 -m venv venv
source venv/bin/activate
```

2. Install required packages:
```bash
pip install -r requirements.txt
```

3. Install the fast-hadamard-transform package:
```bash
pip install --no-build-isolation --no-cache-dir \
  "fast-hadamard-transform @ git+https://github.com/Dao-AILab/fast-hadamard-transform.git"
```
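
`fast-hadamard-transform` provides CUDA kernels for the Walsh-Hadamard transform. To check results on CPU, you can compare against a pure-NumPy reference. The sketch below assumes the package's `hadamard_transform(x, scale)` multiplies the last dimension by the unnormalized, Sylvester-ordered Hadamard matrix and then by `scale`; under that assumption, this reference should match it up to floating-point error. `fwht` is a hypothetical name.

```python
import numpy as np

def fwht(x: np.ndarray, scale: float = 1.0) -> np.ndarray:
    """Fast Walsh-Hadamard transform along the last axis (length a power of 2).

    Computes x @ H.T * scale, with H the Sylvester-ordered Hadamard matrix,
    using O(n log n) butterfly steps instead of an O(n^2) matrix product.
    """
    y = np.array(x, dtype=np.float64, copy=True)
    n = y.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    h = 1
    while h < n:
        # Butterfly: combine elements i and i + h inside each block of size 2h.
        y = y.reshape(*y.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = np.stack((a + b, a - b), axis=-2).reshape(*y.shape[:-3], n)
        h *= 2
    return y * scale
```

With `scale = 1 / sqrt(n)`, the transform is orthonormal and is its own inverse.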

## Quick Start

### Running an Example

Quantize Llama 3.2 1B with GPTQ and learned affine transforms using the provided script:

```bash
./scripts/llama_1b_mxfp4_gptq_ppl.sh
```

Or run directly with Python:

```bash
python3 main.py \
  --model_name meta-llama/Llama-3.2-1B \
  --calibration_dataset wikitext-2 \
  --n_samples 256 \
  --base_dtype bfloat16 \
  --weight_quant_format fp4_e2m1 \
  --act_quant_format fp4_e2m1 \
  --ptq_alg gptq \
  --transform_class_r1 learned_affine \
  --transform_class_r2 learned_affine \
  --matrix_init hadamard \
  --mat_param learnable_inv \
  --learning_rate 3e-5 \
  --weight_decay 0.05 \
  --max_steps 1000 \
  --lr_scheduler cosine \
  --warmup_iters 15 \
  --reg_lambda 1e-1 \
  --batch_size 8 \
  --device cuda
```
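
To sweep a setting, you can build the same command programmatically. The sketch below is a hypothetical helper that uses only the flags and values from the example above. It prints commands as a dry run by default. Any other flag name would need to be checked against `python3 main.py --help`.

```python
import shlex
import subprocess

# Flags and values copied from the example command above.
BASE_ARGS = {
    "model_name": "meta-llama/Llama-3.2-1B",
    "calibration_dataset": "wikitext-2",
    "n_samples": "256",
    "base_dtype": "bfloat16",
    "weight_quant_format": "fp4_e2m1",
    "act_quant_format": "fp4_e2m1",
    "ptq_alg": "gptq",
    "transform_class_r1": "learned_affine",
    "transform_class_r2": "learned_affine",
    "matrix_init": "hadamard",
    "mat_param": "learnable_inv",
    "learning_rate": "3e-5",
    "weight_decay": "0.05",
    "max_steps": "1000",
    "lr_scheduler": "cosine",
    "warmup_iters": "15",
    "reg_lambda": "1e-1",
    "batch_size": "8",
    "device": "cuda",
}

def build_command(overrides=None):
    """Return the argv list for main.py with selected flags overridden."""
    args = {**BASE_ARGS, **(overrides or {})}
    cmd = ["python3", "main.py"]
    for flag, value in args.items():
        cmd += [f"--{flag}", str(value)]
    return cmd

def launch(overrides=None, dry_run=True):
    """Print the command; execute it only when dry_run is False."""
    cmd = build_command(overrides)
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)

# Dry-run sweep over the two PTQ algorithms.
for alg in ("rtn", "gptq"):
    launch({"ptq_alg": alg})
```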

### Available Scripts

The following scripts are provided in `scripts/`:

- `llama_1b_mxfp4_gptq_ppl.sh`: Llama 3.2 1B quantized with GPTQ, evaluated by WikiText-2 perplexity
- `qwen_1.7b_mxfp4_gptq_ppl.sh`: Qwen3 1.7B quantized with GPTQ, evaluated by WikiText-2 perplexity
- `llama_1b_mxfp4_gptq_zshot.sh`: Llama 3.2 1B quantized with GPTQ, evaluated on five zero-shot tasks
- `qwen_1.7b_mxfp4_gptq_zshot.sh`: Qwen3 1.7B quantized with GPTQ, evaluated on five zero-shot tasks


## Command-Line Arguments

Key arguments for customization:

### Model Configuration
- `--model_name`: HuggingFace model identifier
- `--calibration_dataset`: Dataset for calibration (e.g., wikitext-2)
- `--n_samples`: Number of calibration samples (default: 256)
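
`--n_samples` sets how many calibration sequences are drawn. A common recipe in GPTQ-style code samples random contiguous windows from the tokenized calibration text. The sketch below assumes that recipe; the repository's loader may differ. The 2048-token window length and the helper name are hypothetical.

```python
import random

def sample_calibration_windows(token_ids, n_samples=256, seq_len=2048, seed=0):
    """Draw n_samples random contiguous windows of seq_len tokens.

    Hypothetical helper: seq_len=2048 and the sampling scheme are assumptions,
    not values read from the repository.
    """
    if len(token_ids) < seq_len:
        raise ValueError("token stream is shorter than one window")
    rng = random.Random(seed)  # fixed seed gives a reproducible calibration set
    windows = []
    for _ in range(n_samples):
        start = rng.randint(0, len(token_ids) - seq_len)  # randint is inclusive
        windows.append(token_ids[start:start + seq_len])
    return windows

# Toy usage with a fake token stream of 10,000 ids.
windows = sample_calibration_windows(list(range(10_000)), n_samples=4, seq_len=16)
print(len(windows), len(windows[0]))  # 4 16
```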

### Quantization Settings
- `--weight_quant_format`: Weight quantization format (e.g., `fp4_e2m1`, as in the example above; see `--help` for all values)
- `--act_quant_format`: Activation quantization format
- `--ptq_alg`: PTQ algorithm (rtn, gptq)
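
RTN rounds each weight independently. GPTQ instead quantizes one input column at a time and uses the inverse Hessian of the layer's reconstruction loss to push each column's rounding error onto the columns not yet quantized. The compact NumPy sketch below implements both under simplifying assumptions:

- A plain symmetric per-row 4-bit integer grid stands in for the MX formats.
- There is no lazy batching or column reordering.
- The function names are hypothetical, not those in `quantization/quantization_flows/`.

```python
import numpy as np

def row_scales(W, qmax):
    """One symmetric scale per output row, computed from the original weights."""
    s = np.abs(W).max(axis=1) / qmax
    return np.where(s == 0, 1.0, s)  # avoid division by zero for all-zero rows

def quantize(w, scale, qmax):
    """Round to the integer grid [-qmax, qmax] (in units of scale) and dequantize."""
    return np.clip(np.round(w / scale), -qmax, qmax) * scale

def rtn(W, bits=4):
    """Round-to-nearest: every weight is rounded independently."""
    qmax = 2 ** (bits - 1) - 1
    return quantize(W, row_scales(W, qmax)[:, None], qmax)

def gptq(W, X, bits=4, damp=0.01):
    """GPTQ-style quantization of W (d_out, d_in) given calibration inputs X (n, d_in)."""
    qmax = 2 ** (bits - 1) - 1
    scale = row_scales(W, qmax)          # scales fixed before any error feedback
    W = np.array(W, dtype=np.float64)    # working copy that receives the updates
    d_in = W.shape[1]
    # Hessian of the per-row reconstruction loss ||X (w - q)||^2.
    H = 2.0 * X.T @ X
    H += damp * np.mean(np.diag(H)) * np.eye(d_in)  # dampening for numerical stability
    U = np.linalg.cholesky(np.linalg.inv(H)).T      # upper Cholesky factor of H^-1
    Q = np.zeros_like(W)
    for i in range(d_in):
        Q[:, i] = quantize(W[:, i], scale, qmax)
        err = (W[:, i] - Q[:, i]) / U[i, i]
        # Compensate this column's rounding error in the columns not yet quantized.
        W[:, i + 1:] -= np.outer(err, U[i, i + 1:])
    return Q

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 32))
# A shared latent factor makes the input features correlated.
X = rng.standard_normal((512, 32)) + rng.standard_normal((512, 1))
for name, Q in (("rtn", rtn(W)), ("gptq", gptq(W, X))):
    print(name, np.linalg.norm(X @ (W - Q).T))  # output error on the calibration set
```

When the inputs are uncorrelated (a diagonal Hessian), the error feedback vanishes and this GPTQ sketch reduces exactly to RTN. The benefit comes from correlations between input features.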

### Transform Configuration
- `--transform_class_r1`: T1 transformation type (identity, learned, learned_affine)
- `--transform_class_r2`: T2 transformation type
- `--matrix_init`: Transform matrix initialization (orthogonal, hadamard, identity)
- `--mat_param`: Matrix parameterization (learnable, learnable_inv, learnable_qr)
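
The sketch below shows one plausible reading of these options; it is not taken from the repository. It uses the three initializations named by `--matrix_init` and contrasts two generic ways to parameterize a learnable matrix: a general matrix with an explicit inverse, and an orthogonal matrix obtained from an unconstrained parameter via QR. The function names and the mapping to the actual `--mat_param` values are assumptions.

```python
import numpy as np

def init_matrix(n, kind="hadamard", seed=0):
    """Hypothetical initializers mirroring the --matrix_init choices."""
    if kind == "identity":
        return np.eye(n)
    if kind == "hadamard":
        assert n & (n - 1) == 0, "Sylvester construction needs a power of two"
        H = np.array([[1.0]])
        while H.shape[0] < n:
            H = np.block([[H, H], [H, -H]])
        return H / np.sqrt(n)  # orthonormal: H @ H.T == I
    if kind == "orthogonal":
        q, r = np.linalg.qr(np.random.default_rng(seed).standard_normal((n, n)))
        return q * np.sign(np.diag(r))  # column sign fix for a uniformly random Q
    raise ValueError(f"unknown init: {kind}")

def qr_orthogonal(P):
    """Map an unconstrained square parameter P to an orthogonal matrix via QR."""
    q, r = np.linalg.qr(P)
    return q * np.sign(np.diag(r))

def transform_and_inverse(P, param="general"):
    """Return (A, A^-1) for a free parameter P under two assumed parameterizations."""
    if param == "qr":
        A = qr_orthogonal(P)
        return A, A.T  # orthogonal: the inverse is the transpose
    return P, np.linalg.inv(P)  # general invertible matrix: explicit inverse
```

An orthogonal parameterization keeps the inverse cheap and well conditioned. A general invertible matrix is more expressive (it can also rescale directions) but needs an explicit inverse and can become ill conditioned during training.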

### Optimization Settings
- `--learning_rate`: Learning rate for transform optimization
- `--max_steps`: Maximum optimization steps
- `--lr_scheduler`: Learning rate scheduler (cosine, linear, none)
- `--reg_lambda`: Regularization coefficient
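
`--warmup_iters` and `--lr_scheduler` together define how the learning rate evolves over `--max_steps`. The sketch below assumes a linear warmup followed by cosine or linear decay to zero, with `none` meaning a constant rate after warmup. The repository's exact schedule (warmup shape, minimum learning rate) may differ. The defaults are the values from the example command.

```python
import math

def lr_at(step, base_lr=3e-5, warmup_iters=15, max_steps=1000, scheduler="cosine"):
    """Assumed schedule: linear warmup, then cosine, linear, or no decay."""
    if step < warmup_iters:
        return base_lr * (step + 1) / warmup_iters  # ramps up to base_lr
    if scheduler == "none":
        return base_lr
    progress = (step - warmup_iters) / max(1, max_steps - warmup_iters)  # 0 -> 1
    if scheduler == "cosine":
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if scheduler == "linear":
        return base_lr * (1.0 - progress)
    raise ValueError(f"unknown scheduler: {scheduler}")
```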

For a complete list of arguments, run:
```bash
python3 main.py --help
```

## Project Structure

```
mx_algs/
├── main.py                      # Main entry point
├── arg_handler.py               # CLI argument parsing
├── requirements.txt             # Python dependencies
├── scripts/                     # Example run scripts
├── models/                      # Model loading utilities
├── quantization/                # Quantization implementation
│   ├── model_quantization.py
│   ├── quantizer.py
│   ├── transforms/              # Transform implementations
│   └── quantization_flows/      # RTN and GPTQ algorithms
├── transform_optimization/      # Transform training
├── evaluate/                    # Evaluation utilities
└── project_utils/               # Helper utilities
```

## Acknowledgments

This implementation is built on top of the open-source repository provided by:

**FP-Quant**: https://github.com/IST-DASLab/FP-Quant

We thank the authors for making their code publicly available.

## License

This project is licensed under the Apache License 2.0 - see below for details.

```
Copyright 2026 LATMiX Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

