## Sparse Linear Attention

This is the official repository for **S**parse **L**inear **A**ttention (SLA).

### Usage

The `SparseLinearAttention` class (defined in `sparse_linear_attention.core` and importable directly from `sparse_linear_attention`) is the main interface to SLA.

```python
import torch
from sparse_linear_attention import SparseLinearAttention

attn = SparseLinearAttention(
    head_dim=128,
    topk=0.2,                     # = 1 - sparsity
    feat_kernel="softmax",        # options: hedgehog, elu, relu, softmax
    BLKQ=64,
    BLKK=64,
).cuda()

# batch size, number of heads, sequence length, head dimension
B, H, L, D = 2, 4, 4096, 128
q = torch.randn((B, H, L, D), dtype=torch.bfloat16, device='cuda')
k = torch.randn((B, H, L, D), dtype=torch.bfloat16, device='cuda')
v = torch.randn((B, H, L, D), dtype=torch.bfloat16, device='cuda')

# return_sparsity=True additionally returns the sparsity level
o, sparsity = attn(q, k, v, return_sparsity=True)
print("sparsity:", sparsity)
print(o.norm().item())
```
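The `topk` argument is commented as `1 - sparsity`. As a rough illustration of what that ratio means at block granularity, the sketch below mean-pools `BLKQ`-sized query blocks and `BLKK`-sized key blocks, scores every block pair, and keeps the top `topk` fraction of key blocks for each query block. This is a hypothetical NumPy sketch: the helper `block_topk_mask` and its pool-and-rank criterion are assumptions, not the library's kernel, and SLA's actual selection rule (and what it does with the unselected blocks) may differ.

```python
import numpy as np


def block_topk_mask(q, k, blkq=64, blkk=64, topk=0.2):
    """Hypothetical illustration (not the library's API): pick which
    (query-block, key-block) pairs get full attention, keeping a `topk`
    fraction of key blocks per query block."""
    L, D = q.shape
    nq, nk = L // blkq, k.shape[0] // blkk
    # Mean-pool each block of queries / keys into one representative vector.
    qp = q[: nq * blkq].reshape(nq, blkq, D).mean(axis=1)
    kp = k[: nk * blkk].reshape(nk, blkk, D).mean(axis=1)
    # Block-level attention scores, shape (nq, nk).
    scores = qp @ kp.T / np.sqrt(D)
    # Number of key blocks kept for every query block (at least one).
    keep = max(1, int(round(topk * nk)))
    # Indices of the highest-scoring key blocks in each row.
    idx = np.argsort(-scores, axis=1)[:, :keep]
    mask = np.zeros((nq, nk), dtype=bool)
    np.put_along_axis(mask, idx, True, axis=1)
    return mask


rng = np.random.default_rng(0)
L, D = 4096, 128
q = rng.standard_normal((L, D)).astype(np.float32)
k = rng.standard_normal((L, D)).astype(np.float32)

mask = block_topk_mask(q, k, blkq=64, blkk=64, topk=0.2)
sparsity = 1.0 - mask.mean()  # fraction of block pairs skipped
print(mask.shape, f"sparsity={sparsity:.3f}")
```

With `L = 4096` and 64-token blocks there are 64 key blocks per query block, so `topk=0.2` keeps `round(0.2 * 64) = 13` of them, a sparsity of about 0.797 (printed as `(64, 64) sparsity=0.797`).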

### Benchmark

Benchmark the kernel speed with:
```bash
python benchmark.py
```

The output should look similar to the following (exact numbers depend on the GPU):

```plaintext
fused-attention-batch2-head16-d128-fwd-causal=False:
     N_CTX          SLA  Flash Attention
0  32760.0  2991.971936       219.057535
fused-attention-batch2-head16-d128-bwd-causal=False:
     N_CTX          SLA  Flash Attention
0  32760.0  1478.160098       220.197704
```
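The sample numbers can be turned into speedup ratios. This is a small sketch that assumes both columns report a throughput-style metric where higher is better (the output itself does not state its unit); under that assumption the speedup is simply `SLA / Flash Attention` for each pass.

```python
# Figures copied from the sample benchmark output above
# (N_CTX = 32760, batch 2, 16 heads, head dim 128, non-causal).
results = {
    "fwd": {"SLA": 2991.971936, "Flash Attention": 219.057535},
    "bwd": {"SLA": 1478.160098, "Flash Attention": 220.197704},
}

# Assumption: higher is better, so speedup = SLA / baseline.
speedups = {
    pass_name: row["SLA"] / row["Flash Attention"]
    for pass_name, row in results.items()
}

for pass_name, s in speedups.items():
    print(f"{pass_name}: {s:.2f}x over Flash Attention")
```

On these sample numbers this gives roughly 13.66x for the forward pass and 6.71x for the backward pass.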

![SLA acceleration results](results/acceleration.png)
