# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Overview

QuaRot (Quantization through Rotation) is a 4-bit quantization scheme for LLMs that uses orthogonal rotations to remove outliers from activations. It relies on computational invariance: the hidden state is multiplied by an orthogonal (typically randomized Hadamard) matrix whose inverse is folded into the adjacent weights, so the model output is unchanged while outlier features are spread across all channels. This enables end-to-end 4-bit quantization of weights, activations, and KV cache without keeping any channels in higher precision.
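
To illustrate why a rotation helps, here is a minimal pure-Python sketch (not QuaRot code) of a normalized fast Walsh-Hadamard transform applied to a vector with one outlier channel. The function name `fwht` and the example values are illustrative assumptions; QuaRot itself runs Hadamard transforms through CUDA kernels.

```python
import math


def fwht(x):
    """Normalized fast Walsh-Hadamard transform: returns x @ H / sqrt(n).

    H is the Sylvester-ordered Hadamard matrix, so the result is an
    orthogonal rotation of x and applying fwht twice recovers x.
    Runs in O(n log n); len(x) must be a power of two.
    """
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = [float(v) for v in x]
    h = 1
    while h < n:
        # Butterfly stage: combine pairs of entries that are h apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]


# An activation vector with a single large outlier channel (illustrative values).
x = [1.0, -0.5, 0.25, 0.0, 0.5, -1.0, 0.75, 40.0]
y = fwht(x)

print("max |x|  =", max(abs(v) for v in x))
print("max |xH| =", round(max(abs(v) for v in y), 3))
print("L2 norm preserved:", math.isclose(math.hypot(*x), math.hypot(*y)))
```

The maximum magnitude drops from 40 to roughly 15 (close to 40/√8) while the vector's norm is unchanged, which is what makes a single per-token scale far less wasteful.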

## Build and Installation

```bash
# Clone and install with CUDA kernels
git clone https://github.com/spcl/QuaRot.git
cd QuaRot
pip install -e .  # Compiles CUDA kernels during installation
```

Dependencies:
- PyTorch 2.2.1
- transformers 4.38.0
- CUDA-capable GPU (compute capability: sm_75, sm_80, sm_86)
- CMake (for third-party dependencies)

## Core Architecture

### Key Components

1. **Rotation Framework** (`fake_quant/rotation_utils.py`):
   - `fuse_layer_norms()`: Folds LayerNorm/RMSNorm weights into the adjacent linear layers
   - `rotate_model()`: Applies orthogonal/Hadamard rotations to model weights
   - Rotations applied to: embeddings, attention (Q,K,V,O), MLP layers, and LM head

2. **Quantization** (`quarot/`):
   - Int4 CUDA kernels for quantized operations
   - `matmul()`: 4-bit matrix multiplication (requires dimensions divisible by 32)
   - `sym_quant()`/`sym_dequant()`: Symmetric quantization operations
   - `PackedQuantizedTensor`: Container for quantized weights with scales

3. **Hadamard Transforms**:
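   - Applied on the attention V→O path, to the input of the MLP down-projection, and to queries/keys after RoPE for K-cache quantization
   - Computed online during inference, with the matching transforms fused into the adjacent weights where possible
   - Uses the fast Walsh-Hadamard transform from the third-party `fast-hadamard-transform` library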
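
The symmetric int4 scheme behind `sym_quant()`/`sym_dequant()`, and the two-values-per-byte storage that a container like `PackedQuantizedTensor` needs, can be sketched in plain Python as below. This is an illustrative assumption, not the `quarot` API: the helper names are hypothetical, scales are per row, and the nibble order (low nibble first) may differ from the CUDA kernels.

```python
def sym_quant_int4(row):
    """Symmetric per-row int4 quantization: q = clamp(round(x / s), -8, 7), s = max|x| / 7."""
    max_abs = max(abs(v) for v in row)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    # Python's round() is round-half-to-even, like torch.round.
    q = [max(-8, min(7, round(v / scale))) for v in row]
    return q, scale


def sym_dequant_int4(q, scale):
    """Map integer codes back to floats with the shared per-row scale."""
    return [v * scale for v in q]


def pack_int4(q):
    """Pack signed int4 values two per byte (first value in the low nibble)."""
    if len(q) % 2:
        raise ValueError("need an even number of values")
    return bytes((lo & 0xF) | ((hi & 0xF) << 4) for lo, hi in zip(q[0::2], q[1::2]))


def unpack_int4(packed):
    """Inverse of pack_int4: recover the signed int4 codes."""
    vals = []
    for byte in packed:
        for nib in (byte & 0xF, byte >> 4):
            vals.append(nib - 16 if nib >= 8 else nib)  # two's-complement sign extension
    return vals


row = [0.7, -1.4, 0.0, 0.35, -0.2, 1.0, 0.05, -0.9]
q, scale = sym_quant_int4(row)
packed = pack_int4(q)
print(len(row), "values ->", len(packed), "bytes")  # 8 values -> 4 bytes
print(sym_dequant_int4(unpack_int4(packed), scale))
```

Storing one scale per row next to the packed codes halves memory relative to int8 codes and quarters it relative to FP16.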

### Directory Structure
- `quarot/`: Core package with CUDA kernels and PyTorch modules
  - `kernels/`: CUDA implementations (gemm, quantization, flashinfer)
  - `nn/`: PyTorch modules for quantized operations
  - `functional/`: Functional interfaces to CUDA operations
- `fake_quant/`: Simulated quantization experiments (quantization is emulated in floating point; no int4 kernels are executed)
- `e2e/`: End-to-end quantized model implementations
  - `checkpoint_utils/`: Tools for quantizing and rotating checkpoints
  - `quantized_llama/`: Custom LLaMA implementation with QuaRot

## Common Commands

### Running Fake Quantization Experiments
```bash
cd fake_quant
# Basic 4-bit quantization with rotation
python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip

# With zero-shot evaluation
python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --w_bits 4 --lm_eval --tasks winogrande,arc_easy

# Different calibration dataset
python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --w_bits 4 --cal_dataset c4
```

### Quantizing Model Checkpoints
```bash
cd e2e/checkpoint_utils
# Quantize and save model (uses GPTQ by default)
python quantize_llama_checkpoint.py \
  --pretraiend_path_or_name meta-llama/Llama-2-7b-hf \
  --save_path ./quarot_llama2_7b \
  --cal_dataset wikitext2 \
  --nsamples 128

# Using RTN instead of GPTQ
python quantize_llama_checkpoint.py \
  --pretraiend_path_or_name meta-llama/Llama-2-7b-hf \
  --save_path ./quarot_llama2_7b_rtn \
  --w_rtn
```

### Running Benchmarks
```bash
cd e2e
# Prefill-only benchmark
python benchmark.py --batch_size 1 --prefill_seq_len 2048

# Prefill + decode benchmark
python benchmark.py --batch_size 1 --prefill_seq_len 2048 --decode_steps 128

# Larger batch with shorter prefill and fewer decode steps
python benchmark.py --batch_size 8 --prefill_seq_len 512 --decode_steps 64
```

### Component Benchmarks
```bash
cd benchmarks
# Hadamard transform performance
python hadamard_benchmark.py

# Quantized linear layer performance
python qlinear_benchmark.py

# Quantized attention performance
python qattention_benchmark.py
```

## Key Implementation Details

### Rotation Process
1. **Layer Norm Fusion**: Fuses RMSNorm/LayerNorm weights into adjacent linear layers
2. **Orthogonal Rotation**: Applies random orthogonal or Hadamard rotation matrix Q
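3. **Component Rotations** (linear weights stored as `[out, in]`, embeddings as `[vocab, hidden]`):
   - Embeddings: W → W·Q
   - Attention and MLP inputs (W_q, W_k, W_v, W_gate, W_up) and LM head: W → W·Q
   - Attention and MLP outputs (W_o, W_down): W → Q^T·W
   - Hadamard transforms on the attention V→O path and on the input of the MLP down-projection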
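
A small pure-Python check of the steps above, assuming PyTorch's `nn.Linear` convention (`y = x @ W^T`, weights shaped `[out, in]`) and a normalized Sylvester Hadamard matrix as Q. All helper names are hypothetical. It folds an RMSNorm scale into an input-side weight, rotates one input and one output weight, and verifies that the projection is unchanged and the block output is simply rotated by Q.

```python
import math
import random


def hadamard(n):
    """Normalized Sylvester Hadamard matrix; orthogonal and symmetric (n a power of two)."""
    if n <= 0 or n & (n - 1):
        raise ValueError("n must be a power of two")
    H = [[1.0]]
    while len(H) < n:
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    s = 1.0 / math.sqrt(n)
    return [[v * s for v in row] for row in H]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def vecmat(x, M):
    """Row vector times matrix: x @ M."""
    return [sum(a * b for a, b in zip(x, col)) for col in zip(*M)]


def linear(x, W):
    """nn.Linear without bias: x @ W^T, with W shaped [out, in]."""
    return [sum(a * b for a, b in zip(x, row)) for row in W]


def rmsnorm(x, eps=1e-6):
    """RMSNorm without a learned scale (the scale is folded into the next layer)."""
    r = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / r for v in x]


random.seed(0)
d, d_out = 8, 4
h = [random.gauss(0, 1) for _ in range(d)]                              # residual-stream vector
gamma = [random.uniform(0.5, 1.5) for _ in range(d)]                    # RMSNorm weight
W_in = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d_out)]   # e.g. W_q: [d_out, d]
W_out = [[random.gauss(0, 1) for _ in range(d_out)] for _ in range(d)]  # e.g. W_o: [d, d_out]

# Original block: normalize, scale by gamma, project in, project back out.
z = linear([v * g for v, g in zip(rmsnorm(h), gamma)], W_in)
y = linear(z, W_out)

# Step 1: fold gamma into the input-side weight, W_in <- W_in · diag(gamma).
W_in_fused = [[w * g for w, g in zip(row, gamma)] for row in W_in]

# Step 2: rotate. Inputs: W -> W·Q. Outputs: W -> Q^T·W.
Q = hadamard(d)
W_in_rot = matmul(W_in_fused, Q)
W_out_rot = matmul(transpose(Q), W_out)

# The rotated model carries h·Q on the residual stream (e.g. via rotated embeddings).
z_rot = linear(rmsnorm(vecmat(h, Q)), W_in_rot)
y_rot = linear(z_rot, W_out_rot)

print(max(abs(a - b) for a, b in zip(z, z_rot)))             # near 0: projection unchanged
print(max(abs(a - b) for a, b in zip(vecmat(y, Q), y_rot)))  # near 0: output equals y·Q
```

The check works because an orthogonal Q preserves the sum of squares, so RMSNorm(h·Q) = RMSNorm(h)·Q once the learned scale has been moved out of the norm.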

### Quantization Parameters
- `{a,w,v,k}_bits`: Bits for activations/weights/values/keys (default: 16, i.e. no quantization)
- `{a,w,v,k}_groupsize`: Group size for quantization (-1 disables grouping: one scale per token for activations, per output channel for weights)
- `{a,w,v,k}_asym`: Asymmetric quantization flags
- `{a,v,k}_clip_ratio`: Clipping ratios for outlier suppression (weights instead use `--w_clip`, which searches for the clipping ratio during weight quantization)
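
These parameters can be read as arguments to a simulated (quantize-then-dequantize) quantizer. The sketch below is a hypothetical pure-Python illustration of how `bits`, `groupsize`, `asym`, and `clip_ratio` interact on a single row; it is not the `fake_quant` implementation, and the rounding mode and range computation are assumptions.

```python
def fake_quant(x, bits=4, groupsize=-1, asym=False, clip_ratio=1.0):
    """Quantize then dequantize a 1-D row in floating point.

    groupsize=-1 uses a single scale for the whole row; otherwise each
    contiguous group of `groupsize` values gets its own scale (and zero point).
    clip_ratio < 1 shrinks the range so rare outliers saturate instead of
    inflating the scale for every other value.
    """
    if bits >= 16:
        return list(x)  # 16 bits: treated as "no quantization" in this sketch
    g = len(x) if groupsize == -1 else groupsize
    if g <= 0 or len(x) % g:
        raise ValueError("row length must be divisible by groupsize")
    out = []
    for start in range(0, len(x), g):
        grp = x[start:start + g]
        if asym:
            # Unsigned grid [0, 2^bits - 1] with a zero point; the range always includes 0.
            lo = min(min(grp), 0.0) * clip_ratio
            hi = max(max(grp), 0.0) * clip_ratio
            qmax = 2 ** bits - 1
            scale = (hi - lo) / qmax or 1.0
            zero = round(-lo / scale)
            out += [(min(qmax, max(0, round(v / scale) + zero)) - zero) * scale for v in grp]
        else:
            # Signed grid [-2^(bits-1), 2^(bits-1) - 1], no zero point.
            qmax = 2 ** (bits - 1) - 1
            scale = max(abs(v) for v in grp) * clip_ratio / qmax or 1.0
            out += [min(qmax, max(-qmax - 1, round(v / scale))) * scale for v in grp]
    return out


def mse(ref, approx):
    return sum((a - b) ** 2 for a, b in zip(ref, approx)) / len(ref)


x = [0.1, -0.2, 0.3, 4.0, -0.05, 0.02, 0.4, -0.3]
print("per-row  :", mse(x, fake_quant(x, bits=4)))
print("groups=4 :", mse(x, fake_quant(x, bits=4, groupsize=4)))
print("asym     :", mse(x, fake_quant(x, bits=4, asym=True)))
```

Smaller groups confine the outlier's large scale to its own group, at the cost of storing more scales.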

### Model Support
Currently supports:
- LLaMA-2 (7B, 13B, 70B)
- LLaMA-3 (with minor modifications)
- Basic OPT model support

## Important Notes

- All matrix dimensions must be multiples of 32 for int4 operations
- Head dimension must be power of 2 for K-cache quantization
- Flash Attention 2 is used by default for attention
- Weight quantization uses GPTQ (calibrated on 128 samples by default) or round-to-nearest (RTN, `--w_rtn`)
- The setup process automatically compiles the fast-hadamard-transform dependency
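
A hypothetical pre-flight check for the two shape constraints above, written in plain Python. The function name is illustrative and the multiple of 32 is taken from these notes, not from a QuaRot API. The example dimensions are LLaMA-2-7B's (hidden size 4096, intermediate size 11008, head dimension 128).

```python
def check_quarot_shapes(hidden_size, intermediate_size, head_dim, multiple=32):
    """Return a list of violated constraints (an empty list means the shapes fit)."""
    problems = []
    for name, dim in (("hidden_size", hidden_size), ("intermediate_size", intermediate_size)):
        if dim % multiple:
            problems.append(f"{name}={dim} is not a multiple of {multiple} (int4 matmul)")
    if head_dim <= 0 or head_dim & (head_dim - 1):
        problems.append(f"head_dim={head_dim} is not a power of two (required for K-cache quantization)")
    return problems


print(check_quarot_shapes(4096, 11008, 128))  # LLaMA-2-7B: []
print(check_quarot_shapes(4096, 11008, 96))   # flags head_dim
```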