DeltaFormer Pre-Attention
===

**Notes**
* The kernels are tuned specifically for `head_dim=128` and bfloat16.
Other shapes or data types may result in poor performance or even fatal bugs.
* `T` (the sequence length) must be divisible by `128`.
* Triton 3.1.0 is known to contain bugs; please use `triton==3.3.0`.
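If a sequence length is not a multiple of 128, one option is to zero-pad the end of the sequence and slice the output afterwards. Because `u[i]` depends only on positions up to `i`, rows appended after the last real token cannot change the earlier outputs. The helper below, `pad_to_multiple`, is a hypothetical NumPy sketch and not part of `dfpa`; it assumes the sequence axis is second to last, as in a `[B, H, T, D]` layout. For PyTorch tensors, `torch.nn.functional.pad(x, (0, 0, 0, pad))` pads the second-to-last dimension in the same way.

```python
import numpy as np

def pad_to_multiple(x, multiple=128):
    """Hypothetical helper: zero-pad the sequence axis (second to last) of x.

    Returns the padded array and the original length T, so the kernel output
    can be sliced back with u[..., :T, :].
    """
    T = x.shape[-2]
    pad = (-T) % multiple          # 0 when T is already a multiple
    widths = [(0, 0)] * x.ndim     # no padding on the other axes
    widths[-2] = (0, pad)          # append `pad` zero rows at the end
    return np.pad(x, widths), T    # default mode="constant" pads with zeros

# Example with an assumed [B, H, T, D] layout and T = 300.
rng = np.random.default_rng(0)
k = rng.standard_normal((2, 4, 300, 128))
k_padded, T = pad_to_multiple(k)
# k_padded.shape == (2, 4, 384, 128) and T == 300
```

After running the kernel on padded `k` and `v`, keep only `u[..., :T, :]`.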

### Function

Given `k` and `v`, compute `u[i] = v[i] - softmax(k[i] @ k[:i].T) @ u[:i]`
for every position `i`, with high performance and support for the backward pass.
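For checking results on small inputs, the recurrence can be evaluated directly, one position at a time. The sketch below is a hypothetical NumPy reference (`preattn_reference` is not part of `dfpa`) for a single `[T, D]` slice. It follows the formula literally, in float64 and with no extra scaling, and it treats the empty prefix at `i = 0` as contributing nothing, so `u[0] = v[0]`; the actual kernel may differ in details this README does not describe.

```python
import numpy as np

def preattn_reference(k, v):
    """Naive serial evaluation of u[i] = v[i] - softmax(k[i] @ k[:i].T) @ u[:i].

    k, v: float arrays of shape [T, D] for one (batch, head) slice.
    """
    T = k.shape[0]
    u = np.zeros(v.shape, dtype=np.float64)
    u[0] = v[0]  # empty prefix: the softmax term contributes nothing
    for i in range(1, T):
        scores = k[i] @ k[:i].T             # similarities to all earlier keys, shape [i]
        w = np.exp(scores - scores.max())   # subtract the max for numerical stability
        w /= w.sum()                        # softmax over the prefix
        u[i] = v[i] - w @ u[:i]             # mixes previously computed outputs
    return u

# Tiny check: with all-zero keys the softmax is uniform over the prefix.
k = np.zeros((3, 4))
v = np.arange(12, dtype=np.float64).reshape(3, 4)
u = preattn_reference(k, v)
# u[1] = v[1] - v[0]; u[2] = v[2] - (u[0] + u[1]) / 2 = v[2] - v[1] / 2
```

Each step depends on all earlier outputs, which is why a naive loop is slow for long sequences.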

### Usage

```python
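import torch
import dfpa

# Assumed layout [B, H, T, D], as in the benchmark configuration below.
# The kernels are tuned for D = 128 and bfloat16, and T must be a multiple of 128.
B, H, T, D = 2, 32, 8192, 128
k = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

u = dfpa.preattn(k, v)
u.sum().backward()  # the backward pass is supported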
```

### Testing

Run `test/test_attn.py`. It contains the following tests:

* Forward correctness check
* Backward correctness check
* Performance tests for the forward and backward passes
* PyTorch timeline profiling

### Performance

Expected performance on an H800:

* `[B, H, T, D] = [2, 32, 8192, 128]`

Forward:

```
serial           time 279.908 ms
Torch trsv       time 102.156 ms
dfpa             time 12.705 ms
```
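The `Torch trsv` baseline name suggests the recurrence being solved as a triangular system. If the softmax weights are written into a strictly lower-triangular matrix `A` (row `i` holds `softmax(k[i] @ k[:i].T)` in its first `i` columns), the recurrence becomes `u = v - A @ u`, that is `(I + A) @ u = v`. The sketch below is a hypothetical NumPy illustration of that formulation for one `[T, D]` slice, not the benchmark code. It calls the general `np.linalg.solve` for portability; in PyTorch, `torch.linalg.solve_triangular(I + A, v, upper=False, unitriangular=True)` exploits the unit lower-triangular structure.

```python
import numpy as np

def preattn_trsv(k, v):
    """Hypothetical sketch: solve (I + A) u = v with A[i, :i] = softmax(k[i] @ k[:i].T)."""
    T = k.shape[0]
    scores = k @ k.T                                   # [T, T] pairwise key similarities
    mask = np.tril(np.ones((T, T), dtype=bool), k=-1)  # True strictly below the diagonal
    scores = np.where(mask, scores, -np.inf)           # each row sees only earlier positions
    A = np.zeros((T, T))
    # Row 0 has no earlier positions, so it stays all zeros (u[0] = v[0]).
    row_max = scores[1:].max(axis=1, keepdims=True)
    e = np.exp(scores[1:] - row_max)                   # masked entries become exp(-inf) = 0
    A[1:] = e / e.sum(axis=1, keepdims=True)
    # I + A is unit lower-triangular; a general solver is used here for simplicity.
    return np.linalg.solve(np.eye(T) + A, v)

# Same tiny case as a serial evaluation: all-zero keys give uniform prefix weights.
k = np.zeros((3, 4))
v = np.arange(12, dtype=np.float64).reshape(3, 4)
u = preattn_trsv(k, v)
```

This formulation materializes a `T x T` matrix per head, whereas the serial form needs none but runs `T` dependent steps.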

Backward:

```
Torch trsv backward time 275.659 ms
dfpa       backward time 25.713 ms
```
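The speedups implied by these timings can be worked out directly. The snippet below only divides the numbers listed above; `speedups` is a hypothetical helper name.

```python
# Timings (ms) copied from the forward and backward results above.
forward = {"serial": 279.908, "Torch trsv": 102.156, "dfpa": 12.705}
backward = {"Torch trsv": 275.659, "dfpa": 25.713}

def speedups(times, target="dfpa"):
    """Return how many times faster `target` is than each other entry."""
    return {name: round(t / times[target], 2)
            for name, t in times.items() if name != target}

fwd = speedups(forward)
bwd = speedups(backward)
print(fwd)  # {'serial': 22.03, 'Torch trsv': 8.04}
print(bwd)  # {'Torch trsv': 10.72}
```

So on this configuration `dfpa` is roughly 22x faster than the serial loop and 8x faster than the triangular-solve baseline in the forward pass, and roughly 10.7x faster in the backward pass.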

For comparison, reference timings of other operations in training:

```
Typical normal attention    time 46.839 ms
MLP hto4h GeMM              time 2.918 ms
FlashAttention forward      time 3.406 ms
FlashAttention backward     time 9.348 ms
```
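To put these references next to the `dfpa` numbers, the snippet below divides the `dfpa` timings by the FlashAttention timings for each pass. It assumes both sets of numbers were measured at the same `[B, H, T, D]` configuration, which this README does not state explicitly.

```python
# Timings (ms) copied from this section.
dfpa_ms = {"forward": 12.705, "backward": 25.713}
flash_ms = {"forward": 3.406, "backward": 9.348}

# Ratio > 1 means dfpa takes longer than FlashAttention for that pass.
ratios = {phase: round(dfpa_ms[phase] / flash_ms[phase], 2) for phase in dfpa_ms}
for phase, r in ratios.items():
    print(f"{phase}: dfpa takes {r:.2f}x the FlashAttention time")
# forward: dfpa takes 3.73x the FlashAttention time
# backward: dfpa takes 2.75x the FlashAttention time
```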
