# Residual Linear Attention - Supplementary Code

This repository contains the supplementary code for our ICLR 2026 paper submission on **Enhancing Linear Attention with Residual Learning**. Our work introduces novel architectures that enhance linear attention mechanisms with explicit residual fitting, improving over the corresponding baselines while retaining the linear-time complexity of linear attention.

**Note**: This code is provided solely for reproducibility and to support paper review. It is not intended for production use. Some code snippets are adapted from https://github.com/fla-org/flash-linear-attention.

## 🚀 Key Features

- **Residual Linear Attention (RLA)**: Enhanced linear attention with explicit residual fitting
- **Residual Delta Net (RDN)**: Enhanced delta net with explicit residual fitting
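For readers less familiar with the baselines, here is a minimal sketch of the standard token-by-token recurrences behind gated linear attention and the gated delta rule, written in plain PyTorch. The function names are hypothetical and not part of this repository. We assume `alpha` acts as the per-head log-space decay and `beta` as the delta-rule step size, consistent with the ranges given in Basic Usage; query scaling, Q/K L2 normalization, `gamma`, and the residual state returned as `R` are deliberately left out, since their exact roles are defined by the paper and the kernels.

```python
import torch


def naive_gated_linear_attention(q, k, v, alpha):
    """Gated linear attention, one token at a time (hypothetical reference, not this repo's API).

    S_t = exp(alpha_t) * S_{t-1} + k_t v_t^T,  o_t = S_t^T q_t
    Shapes: q, k [B, T, H, D]; v [B, T, H, Dv]; alpha [B, T, H] (log-space decay).
    """
    B, T, H, D = q.shape
    S = q.new_zeros(B, H, D, v.shape[-1])  # one D x Dv state per head
    out = torch.empty_like(v)
    for t in range(T):
        decay = alpha[:, t].exp()[..., None, None]  # [B, H, 1, 1]
        S = decay * S + k[:, t, :, :, None] * v[:, t, :, None, :]  # add outer product k_t v_t^T
        out[:, t] = torch.einsum("bhd,bhde->bhe", q[:, t], S)  # read the state out with q_t
    return out, S


def naive_gated_delta_rule(q, k, v, alpha, beta):
    """Gated delta rule (hypothetical reference, not this repo's API).

    S_t = exp(alpha_t) (I - beta_t k_t k_t^T) S_{t-1} + beta_t k_t v_t^T,  o_t = S_t^T q_t
    """
    B, T, H, D = q.shape
    S = q.new_zeros(B, H, D, v.shape[-1])
    out = torch.empty_like(v)
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]  # [B, H, D], [B, H, Dv]
        S = alpha[:, t].exp()[..., None, None] * S  # gated decay of the previous state
        pred = torch.einsum("bhd,bhde->bhe", k_t, S)  # what the decayed state returns for k_t
        # Delta-rule correction: move the prediction for k_t toward v_t with step size beta_t.
        S = S + beta[:, t][..., None, None] * k_t[..., :, None] * (v_t - pred)[..., None, :]
        out[:, t] = torch.einsum("bhd,bhde->bhe", q[:, t], S)
    return out, S
```

Each step touches only a fixed-size state, so cost grows linearly with sequence length; optimized kernels such as those in flash-linear-attention typically evaluate recurrences like these in a chunked, parallel form rather than with a Python loop.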

## 🛠️ Prerequisites
- Python 3.12+
- PyTorch 2.4+
- Triton 3.2.0+
- Flash Attention 2.6+ (for benchmarking)
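A quick way to confirm that an environment meets these minimums is a small version check. This is a hedged sketch, not a script shipped with the repository; it assumes the PyPI distribution names `torch`, `triton`, and `flash-attn`, and it compares only the leading numeric release (so `2.4.1+cu121` is read as 2.4.1).

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the Prerequisites list, keyed by (assumed) PyPI distribution name.
MINIMUMS = {
    "torch": (2, 4),
    "triton": (3, 2, 0),
    "flash-attn": (2, 6),  # only needed for benchmarking
}


def parse_version(text):
    """Extract the leading numeric release, e.g. '2.4.1+cu121' -> (2, 4, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets(found, minimum):
    """Compare release tuples after zero-padding, so (2, 4) equals (2, 4, 0)."""
    width = max(len(found), len(minimum))
    found = found + (0,) * (width - len(found))
    minimum = minimum + (0,) * (width - len(minimum))
    return found >= minimum


def status(found, minimum):
    text = ".".join(map(str, found))
    return f"ok ({text})" if meets(found, minimum) else f"too old ({text})"


def check_prerequisites():
    """Return a status string ('ok (...)', 'too old (...)', or 'missing') per requirement."""
    report = {"python": status(tuple(sys.version_info[:3]), (3, 12))}
    for dist, minimum in MINIMUMS.items():
        try:
            found = parse_version(version(dist))
        except PackageNotFoundError:
            report[dist] = "missing"
            continue
        report[dist] = status(found, minimum)
    return report


if __name__ == "__main__":
    for name, result in check_prerequisites().items():
        print(f"{name}: {result}")
```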

## 🎯 Quick Start

### Basic Usage

```python
import torch
from residual_linear_attention import rla_prefill, rdn_prefill

# Generate example data
batch_size, seq_len, num_heads, head_dim = 2, 16384, 32, 128
device = "cuda"
dtype = torch.bfloat16

# Query, Key, Value tensors
q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

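# Alpha parameter (strictly negative, in log scale); logsigmoid keeps it finite, whereas log() of a zero sample would give -inf
alpha = torch.nn.functional.logsigmoid(torch.randn(batch_size, seq_len, num_heads, device=device, dtype=torch.float32))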

# Beta parameter (in [0, 1])
beta = torch.sigmoid(torch.randn(batch_size, seq_len, num_heads, device=device, dtype=torch.float32))

# Gamma parameter (in [0, 1])
gamma = torch.sigmoid(torch.randn(batch_size, seq_len, num_heads, device=device, dtype=torch.float32))

# Residual Linear Attention  
o_rla, S_rla, R_rla = rla_prefill(q, k, v, alpha, beta, gamma)

# Residual Delta Net
o_rdn, S_rdn, R_rdn = rdn_prefill(q, k, v, alpha, beta, gamma)

print(f"Output shape: {o_rla.shape}")  # [batch_size, seq_len, num_heads, head_dim]
```
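Before calling the kernels, it can help to check inputs against the constraints stated in the comments of the example: a shared `[batch_size, seq_len, num_heads, head_dim]` layout, `alpha < 0` in log scale, and `beta` and `gamma` in `[0, 1]`. The helper below is a hypothetical sketch, not part of `residual_linear_attention`; it also rejects non-finite `alpha` values (for example `-inf` from taking the log of a zero sample). It only requires `v` to match `q` on the first three dimensions, since whether the kernels accept a different value dimension is not stated here.

```python
import torch


def check_prefill_inputs(q, k, v, alpha, beta, gamma):
    """Raise ValueError if inputs break the documented constraints (hypothetical helper)."""
    if q.dim() != 4 or k.shape != q.shape:
        raise ValueError(f"q and k must share a [B, T, H, D] shape, got {tuple(q.shape)} and {tuple(k.shape)}")
    if v.dim() != 4 or v.shape[:3] != q.shape[:3]:
        raise ValueError(f"v must match q on [B, T, H], got {tuple(v.shape)}")
    gate_shape = q.shape[:3]
    for name, gate in (("alpha", alpha), ("beta", beta), ("gamma", gamma)):
        if gate.shape != gate_shape:
            raise ValueError(f"{name} must have shape {tuple(gate_shape)}, got {tuple(gate.shape)}")
    # alpha is a log-space decay: it must be finite and strictly negative.
    if not torch.isfinite(alpha).all() or (alpha >= 0).any():
        raise ValueError("alpha must be finite and strictly negative (log-space decay)")
    for name, gate in (("beta", beta), ("gamma", gamma)):
        if ((gate < 0) | (gate > 1)).any():
            raise ValueError(f"{name} must lie in [0, 1]")
```

Call it as `check_prefill_inputs(q, k, v, alpha, beta, gamma)` right before `rla_prefill` or `rdn_prefill`. Converting the `.any()` results to Python booleans synchronizes with the GPU, so keep such checks out of timed benchmark loops.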

### Advanced Usage with Custom Parameters

```python
# Advanced configuration (RLA and RDN have the same kwargs)
o_rla, S_rla, R_rla = rla_prefill(
    q, k, v, alpha, beta, gamma,
    cu_seqlens=None,            # Cumulative sequence lengths for packed variable-length inputs; if set, the batch dim of all inputs must be 1
    initial_S=None,             # Initial state matrix for S; if None, initial state is all zeros
    initial_R=None,             # Initial state matrix for R; if None, initial state is all zeros
    output_final_state=True,    # Whether to output final state (S and R)
    scale=None,                 # If None, defaults to head_dim ** -0.5
    rclip=1.0,                  # Residual clipping factor
    l2_qk_norm=True,            # L2 normalization for Q/K
)
```
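When `cu_seqlens` is set, sequences of different lengths are packed along the time axis into a single batch row. The sketch below assumes the flash-linear-attention convention, in which `cu_seqlens` is a 1-D integer tensor of `N + 1` cumulative offsets starting at 0; `pack_varlen` and `unpack_varlen` are hypothetical helpers, and the exact dtype and device the kernels expect should be checked against the code.

```python
import torch


def pack_varlen(seqs):
    """Concatenate per-sequence tensors of shape [T_i, ...] into [1, sum(T_i), ...].

    Returns the packed tensor and cu_seqlens = [0, T_0, T_0 + T_1, ...] (assumed format).
    """
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.long)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    packed = torch.cat(seqs, dim=0).unsqueeze(0)  # batch dim of 1, as required when cu_seqlens is set
    return packed, cu_seqlens


def unpack_varlen(packed, cu_seqlens):
    """Split a packed [1, total, ...] tensor back into a list of per-sequence tensors."""
    bounds = cu_seqlens.tolist()
    return [packed[0, start:end] for start, end in zip(bounds[:-1], bounds[1:])]
```

For example, pack each of `q`, `k`, `v` (per-sequence shape `[T_i, num_heads, head_dim]`) and each of `alpha`, `beta`, `gamma` (`[T_i, num_heads]`) with the same list of lengths, pass `cu_seqlens=cu_seqlens.to(q.device)` to `rla_prefill`, and split the output with `unpack_varlen`.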

## 📈 Performance Benchmarking

We benchmark our methods against Flash Attention and the corresponding linear-attention baselines (sGLA and GDN). See the included notebook `test.ipynb` for the full performance analysis.

### Benchmark Results Summary

The results below were obtained on a single GPU with the following configuration:
- Batch size: 1
- Number of heads: 32  
- Head dimension: 128
- Data type: bfloat16

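Because of the explicit residual fitting, our methods roughly double the computation of their respective baselines (sGLA for RLA, GDN for RDN). They still scale linearly with sequence length, however, so they become much faster than full attention once sequences are long enough: in this setup, RLA overtakes Flash Attention at 8K tokens and RDN at 16K, and at 128K RLA takes 17.35 ms and RDN 27.52 ms versus 403.150 ms for Flash Attention.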

| Sequence Length | Flash Attention (ms) | sGLA (ms) | RLA (ms) | GDN (ms) | RDN (ms) |
|-----------------|----------------------|-----------|----------|----------|----------|
| 1K              | 0.058                | 0.67      | 1.40     | 0.63     | 1.10     |
| 2K              | 0.156                | 0.71      | 1.41     | 0.82     | 1.41     |
| 4K              | 0.459                | 0.48      | 0.72     | 0.69     | 1.87     |
| 8K              | 1.594                | 0.69      | 1.47     | 0.85     | 1.80     |
| 16K             | 5.991                | 0.99      | 2.22     | 1.66     | 3.53     |
| 32K             | 23.571               | 1.98      | 4.40     | 3.31     | 6.97     |
| 64K             | 98.274               | 3.92      | 8.72     | 6.56     | 13.90    |
| 128K            | 403.150              | 7.78      | 17.35    | 13.07    | 27.52    |
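The scaling claims can be checked directly against the table. The short script below uses plain Python with the latencies transcribed from the table, and computes three things: how much each method slows down when the sequence length doubles from 64K to 128K (about 2x indicates linear scaling, about 4x quadratic), the overhead of each residual variant over its baseline at 128K, and the shortest tested length at which each method beats Flash Attention. Pairing RLA with sGLA and RDN with GDN is our reading of the table's column order.

```python
from typing import Optional

# Latencies in ms, transcribed from the benchmark table (batch 1, 32 heads, head_dim 128, bfloat16).
LENGTHS_K = [1, 2, 4, 8, 16, 32, 64, 128]
LATENCY_MS = {
    "Flash Attention": [0.058, 0.156, 0.459, 1.594, 5.991, 23.571, 98.274, 403.150],
    "sGLA": [0.67, 0.71, 0.48, 0.69, 0.99, 1.98, 3.92, 7.78],
    "RLA": [1.40, 1.41, 0.72, 1.47, 2.22, 4.40, 8.72, 17.35],
    "GDN": [0.63, 0.82, 0.69, 0.85, 1.66, 3.31, 6.56, 13.07],
    "RDN": [1.10, 1.41, 1.87, 1.80, 3.53, 6.97, 13.90, 27.52],
}
BASELINE_OF = {"RLA": "sGLA", "RDN": "GDN"}  # assumed pairing, from the column order


def doubling_growth(method: str) -> float:
    """Latency ratio between the two longest tested lengths (64K -> 128K)."""
    times = LATENCY_MS[method]
    return times[-1] / times[-2]


def overhead(method: str) -> float:
    """Latency of a residual variant divided by that of its baseline at 128K."""
    return LATENCY_MS[method][-1] / LATENCY_MS[BASELINE_OF[method]][-1]


def crossover_k(method: str) -> Optional[int]:
    """Shortest tested length (in K tokens) at which `method` is faster than Flash Attention."""
    for length, ours, full in zip(LENGTHS_K, LATENCY_MS[method], LATENCY_MS["Flash Attention"]):
        if ours < full:
            return length
    return None


if __name__ == "__main__":
    for method in LATENCY_MS:
        print(f"{method:>15}: x{doubling_growth(method):.2f} from 64K to 128K")
    for method, base in BASELINE_OF.items():
        print(f"{method} / {base} at 128K: {overhead(method):.2f}x")
        print(f"{method} first beats Flash Attention at {crossover_k(method)}K tokens")
```

On these numbers, Flash Attention grows by about 4.1x from 64K to 128K while RLA grows by about 2.0x; at 128K, RLA and RDN cost about 2.2x and 2.1x their baselines.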
