# CUTLASS Epilogues

## Introduction

This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.

Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).

There are 4 epilogues:

1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.

To reduce the final binary size, there are no epilogues for asymmetric quantization of activations without bias.
Instead, if no bias is passed, the epilogue uses 0 as the bias.
This adds a redundant addition (and a runtime check), but the performance impact is minor.

## Underlying Linear Algebra

More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).

If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:

```math
A = s_a (\widehat A - J_a z_a)
```

```math
B = s_b \widehat B
```

```math
D = A B + C
```

```math
D = s_a s_b \widehat D + C
```

Here, D is the output of the GEMM, and C is the bias.
A holds the activations and supports asymmetric quantization,
and B holds the weights and only supports symmetric quantization.
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the all-ones matrix with the same dimensions as A.
Additional epilogues would be required to support asymmetric quantization for weights.

Expanding further, we can calculate $` \widehat D `$ as follows:

```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```

```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
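A quick numeric check can make the derivation concrete. The following pure-Python sketch assumes small hand-picked integer matrices, per-tensor scales, and a scalar zero-point (all illustrative). It dequantizes and multiplies directly, then compares that with the epilogue form $` s_a s_b (\widehat A \widehat B - z_a J_a \widehat B) `$, where every row of $` J_a \widehat B `$ is the column-sum vector $` \mathbf 1 \widehat B `$. The helpers `matmul` and `column_sums` are hypothetical and not part of CUTLASS or vLLM.

```python
# Hypothetical numeric check of the derivation, for a per-tensor
# (scalar) activation scale s_a, weight scale s_b, and zero-point z_a.

def matmul(x, y):
    """Naive matrix multiply for lists of lists (x is M x K, y is K x N)."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]

def column_sums(m):
    """Row-vector 1 m: the sum of each column of m."""
    return [sum(row[j] for row in m) for j in range(len(m[0]))]

# Illustrative quantized activations (2 x 3) and weights (3 x 2).
A_hat = [[3, -1, 4], [1, 5, -9]]
B_hat = [[2, 6], [-5, 3], [5, -8]]
s_a, s_b, z_a = 0.5, 0.25, 2

# Reference path: dequantize A and B first, then multiply in floating point.
A = [[s_a * (a - z_a) for a in row] for row in A_hat]
B = [[s_b * b for b in row] for row in B_hat]
reference = matmul(A, B)

# Epilogue path: integer GEMM, subtract z_a times the column sums of B_hat
# (every row of J_a B_hat), then apply s_a s_b once.
raw = matmul(A_hat, B_hat)
col_sums = column_sums(B_hat)
D_hat = [[d - z_a * c for d, c in zip(row, col_sums)] for row in raw]
fused = [[s_a * s_b * d for d in row] for row in D_hat]

assert all(abs(r - f) < 1e-9
           for r_row, f_row in zip(reference, fused)
           for r, f in zip(r_row, f_row))
```

Only the integer GEMM and one row-vector subtraction remain at runtime, which is why the column sums can be computed once per weight matrix.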

## Epilogues

### `ScaledEpilogue`

This epilogue handles symmetric quantization of activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```

```math
D = s_a s_b \widehat D
```

```math
D = s_a s_b \widehat A \widehat B
```

Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
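To illustrate how the two scale shapes broadcast, here is a hypothetical pure-Python reference for the `ScaledEpilogue` math (not the CUTLASS kernel). It assumes the raw accumulator is an M x N list of lists and accepts each scale either as a scalar or as a vector; `broadcast` and `scaled_epilogue` are illustrative names.

```python
# Hypothetical pure-Python reference for the ScaledEpilogue math,
# D = s_a s_b (A_hat B_hat). Not the CUTLASS implementation.

def broadcast(scale, length):
    """Expand a per-tensor scalar to a vector; validate per-row/column vectors."""
    if isinstance(scale, (int, float)):
        return [float(scale)] * length
    if len(scale) != length:
        raise ValueError(f"expected {length} scales, got {len(scale)}")
    return [float(s) for s in scale]

def scaled_epilogue(acc, scale_a, scale_b):
    """Scale the raw GEMM accumulator acc (M x N) by scale_a and scale_b."""
    m, n = len(acc), len(acc[0])
    sa = broadcast(scale_a, m)  # column-vector: one scale per token (row)
    sb = broadcast(scale_b, n)  # row-vector: one scale per channel (column)
    return [[sa[i] * sb[j] * acc[i][j] for j in range(n)] for i in range(m)]

# Illustrative int32-style accumulator, i.e. A_hat B_hat.
acc = [[31, -17], [-68, 93]]
per_tensor = scaled_epilogue(acc, 0.5, 0.25)
per_token_channel = scaled_epilogue(acc, [0.5, 1.0], [0.25, 2.0])
```

A scalar applies to the whole output, a per-token vector to rows, and a per-channel vector to columns, so the product `sa[i] * sb[j]` covers all four combinations with one code path.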

### `ScaledEpilogueBias`

This epilogue handles symmetric quantization of activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```

```math
D = s_a s_b \widehat D + C 
```

```math
D = s_a s_b \widehat A \widehat B + C
```

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, which is always per-channel (row-vector).
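A minimal sketch of the `ScaledEpilogueBias` math follows, under the same kind of assumptions (list-of-lists accumulator, scalar or vector scales, illustrative helper names that are not CUTLASS or vLLM APIs). The only difference from `ScaledEpilogue` is the per-channel `bias` row-vector, added to every row after scaling.

```python
# Hypothetical reference for the ScaledEpilogueBias math,
# D = s_a s_b (A_hat B_hat) + C, with C a per-channel bias row-vector.

def as_vector(scale, length):
    """Expand a per-tensor scalar to a vector; pass vectors through."""
    if isinstance(scale, (int, float)):
        return [float(scale)] * length
    return [float(s) for s in scale]

def scaled_epilogue_bias(acc, scale_a, scale_b, bias):
    m, n = len(acc), len(acc[0])
    sa, sb = as_vector(scale_a, m), as_vector(scale_b, n)
    if len(sa) != m or len(sb) != n or len(bias) != n:
        raise ValueError("shape mismatch between accumulator, scales, and bias")
    # Scale each element, then add bias[j] to every row of column j.
    return [[sa[i] * sb[j] * acc[i][j] + bias[j] for j in range(n)]
            for i in range(m)]

acc = [[31, -17], [-68, 93]]  # illustrative A_hat B_hat
out = scaled_epilogue_bias(acc, 0.5, [0.25, 2.0], bias=[1.0, -1.0])
```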

### `ScaledEpilogueAzp`

This epilogue handles asymmetric per-tensor quantization of activations with bias.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

```math
D = s_a s_b \widehat D + C 
```

```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```

Because $` z_a `$ is a scalar, every row of the zero-point term $` z_a J_a \widehat B `$ is equal to $` z_a \mathbf 1 \widehat B `$.
That row-vector is precomputed and stored in `azp_with_adj`.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a \mathbf 1 \widehat B `$, i.e. one row of $` z_a J_a \widehat B `$), which is per-channel (row-vector).
- `bias` is the bias, which is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
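The split between offline and runtime work can be sketched as follows. This is a hypothetical pure-Python illustration, not the vLLM or CUTLASS code: `precompute_azp_with_adj` and `scaled_epilogue_azp` are made-up names, `scale_a` is assumed to be a per-tensor scalar, and a missing bias is replaced by zeros, mirroring the fallback described in the introduction.

```python
# Hypothetical sketch of the ScaledEpilogueAzp math with a per-tensor
# zero-point z_a. Helper names are illustrative, not vLLM or CUTLASS APIs.

def matmul(x, y):
    """Naive matrix multiply for lists of lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]

def precompute_azp_with_adj(b_hat, z_a):
    """Offline step: z_a times the column sums of B_hat (a row-vector)."""
    return [z_a * sum(row[j] for row in b_hat) for j in range(len(b_hat[0]))]

def scaled_epilogue_azp(acc, scale_a, scale_b, azp_with_adj, bias=None):
    """Runtime step: D = s_a s_b (acc - azp_with_adj) + C.

    scale_a is assumed to be a per-tensor scalar; scale_b may be a scalar
    or a per-channel list. A missing bias is treated as all zeros.
    """
    m, n = len(acc), len(acc[0])
    sb = [scale_b] * n if isinstance(scale_b, (int, float)) else list(scale_b)
    if bias is None:
        bias = [0.0] * n
    return [[scale_a * sb[j] * (acc[i][j] - azp_with_adj[j]) + bias[j]
             for j in range(n)]
            for i in range(m)]

A_hat = [[3, -1, 4], [1, 5, -9]]
B_hat = [[2, 6], [-5, 3], [5, -8]]
azp_with_adj = precompute_azp_with_adj(B_hat, z_a=2)  # computed once, offline
acc = matmul(A_hat, B_hat)                             # raw integer GEMM
out_no_bias = scaled_epilogue_azp(acc, 0.5, 0.25, azp_with_adj)
out_bias = scaled_epilogue_azp(acc, 0.5, 0.25, azp_with_adj, bias=[1.0, 0.5])
```

With these illustrative inputs, `out_no_bias` equals the result of dequantizing A and B first and multiplying in floating point.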

### `ScaledEpilogueAzpPerToken`

This epilogue handles asymmetric per-token quantization of activations with bias.

The output of the GEMM is the same as above, but $` z_a `$ is now a column-vector with one zero-point per token.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.

Epilogue parameters:

- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), which is per-channel (row-vector).
- `azp` is the zero-point ($` z_a `$), which is per-token (column-vector).
- `bias` is the bias, which is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.

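The epilogue performs the following computation, where `Dq` is the raw quantized output of the GEMM and $` \mathtt{azp} \, \mathtt{azp\_adj} `$ is the outer product of the per-token zero-points with the column sums of $` \widehat B `$:

```math
\mathtt{out} = \mathtt{scale\_a} \cdot \mathtt{scale\_b} \cdot \left( \mathtt{Dq} - \mathtt{azp} \, \mathtt{azp\_adj} \right) + \mathtt{bias}
```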
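The per-token variant can be sketched the same way. In this hypothetical pure-Python illustration, `azp_adj` is the unscaled column-sum vector computed offline, the zero-point and activation scale are per-token lists, and the outer product entry `azp[i] * azp_adj[j]` is formed element by element inside the epilogue. The result is cross-checked against dequantizing each row with its own scale and zero-point. Function names and values are illustrative assumptions, not vLLM or CUTLASS APIs.

```python
# Hypothetical sketch of the ScaledEpilogueAzpPerToken math:
# out = scale_a * scale_b * (Dq - outer(azp, azp_adj)) + bias.
# Helper names and values are illustrative only.

def matmul(x, y):
    """Naive matrix multiply for lists of lists."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]

def precompute_azp_adj(b_hat):
    """Offline step: column sums of B_hat, i.e. the row-vector 1 B_hat."""
    return [sum(row[j] for row in b_hat) for j in range(len(b_hat[0]))]

def scaled_epilogue_azp_per_token(dq, scale_a, scale_b, azp_adj, azp, bias=None):
    """scale_a and azp are per-token lists (length M); scale_b is a scalar."""
    m, n = len(dq), len(dq[0])
    if bias is None:
        bias = [0.0] * n
    # azp[i] * azp_adj[j] is the (i, j) entry of the outer product.
    return [[scale_a[i] * scale_b * (dq[i][j] - azp[i] * azp_adj[j]) + bias[j]
             for j in range(n)]
            for i in range(m)]

A_hat = [[3, -1, 4], [1, 5, -9]]
B_hat = [[2, 6], [-5, 3], [5, -8]]
scale_a, scale_b, azp = [0.5, 1.0], 0.25, [2, -1]  # one zero-point per token

azp_adj = precompute_azp_adj(B_hat)
out = scaled_epilogue_azp_per_token(matmul(A_hat, B_hat), scale_a, scale_b,
                                    azp_adj, azp)

# Cross-check against dequantizing each row with its own scale and zero-point.
A = [[scale_a[i] * (a - azp[i]) for a in A_hat[i]] for i in range(len(A_hat))]
B = [[scale_b * b for b in row] for row in B_hat]
reference = matmul(A, B)
assert all(abs(o - r) < 1e-9
           for o_row, r_row in zip(out, reference)
           for o, r in zip(o_row, r_row))
```

Keeping $` z_a `$ out of `azp_adj` is what lets a single precomputed vector serve every token: each token's zero-point is applied only at runtime.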
