# FourierRoFormer: Learned Fourier Attention for Vision Transformers

This repository contains the official PyTorch implementation of the paper "FourierRoFormer: Learned Fourier Attention for Vision Transformers" submitted to NeurIPS 2025. FourierRoFormer enhances Vision Transformers with frequency-aware attention through learned Fourier components, improving multi-scale feature representation.

## Introduction

Vision Transformers (ViTs) excel at long-range reasoning, but they offer limited control over how spatial frequency is represented and over how quickly attention should decay with token distance. FourierRoFormer addresses these limitations by augmenting rotary positional embeddings (RoPE) with:

1. A learned mixture of sinusoidal basis functions to capture patterns at different spatial frequencies
2. Learnable damping coefficients to control how strongly frequencies influence distant tokens

The result is an attention kernel that is expressive, interpretable, and well-conditioned, yielding statistically significant gains in image classification accuracy while adding only about 0.04% more parameters.
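The README does not spell out the exact kernel, so the following is only a minimal sketch under an assumed form: a weight over the token offset Δ given by `w(Δ) = Σ_j a_j · cos(ω_j · Δ) · exp(-γ_j · |Δ|)`, where the amplitudes `a_j` and frequencies `ω_j` stand in for the learned sinusoidal mixture (item 1) and the `γ_j` for the learnable damping coefficients (item 2). How FourierRoFormer actually combines such a term with the RoPE score (for example, multiplicatively or as an additive logit bias) is not stated here, and `damped_fourier_weight` is a hypothetical name.

```python
import math
from typing import Sequence


def damped_fourier_weight(
    delta: float,
    amplitudes: Sequence[float],
    frequencies: Sequence[float],
    dampings: Sequence[float],
) -> float:
    """Hypothetical relative-position weight w(delta).

    Sums K cosine components a_j * cos(omega_j * delta), each attenuated by
    exp(-gamma_j * |delta|). The three sequences play the role of learned
    parameters; the parameterization used in the repo may differ.
    """
    if not (len(amplitudes) == len(frequencies) == len(dampings)):
        raise ValueError("amplitudes, frequencies and dampings must have equal length")
    total = 0.0
    for a, omega, gamma in zip(amplitudes, frequencies, dampings):
        # abs() keeps the damping non-negative, so each component can only decay with distance
        decay = math.exp(-abs(gamma) * abs(delta))
        total += a * math.cos(omega * delta) * decay
    return total


# Two components: a slow, undamped one (global context) and a fast, damped one (local detail)
params = dict(amplitudes=[0.5, 0.5], frequencies=[0.1, 1.5], dampings=[0.0, 0.3])
for d in (0, 1, 4, 16):
    print(d, round(damped_fourier_weight(d, **params), 4))
```

Under this assumed form, a larger `γ_j` shortens a component's reach, which is one way learned parameters could trade local detail against global context.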

## Key Features

- **Frequency-aware attention**: Learns which frequency patterns are most relevant for visual understanding
- **Adaptive control**: Balances local detail and global context through learned parameters
- **Compatibility**: Maintains the geometric properties of RoPE while adding frequency sensitivity
- **Efficiency**: Adds only about 0.04% more parameters (e.g., 24.75M → 24.76M)
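The **Compatibility** point refers to two properties of standard RoPE: rotating query/key channel pairs preserves vector norms, and the resulting dot product depends only on the relative offset between the two positions. The stdlib-only sketch below checks both for 1D positions with the common base of 10000; it is an illustration of standard RoPE, and the repo's `RotaryPositionalEncoding` may use a different base or a 2D (row/column) layout for image patches.

```python
import math
from typing import List


def rope_rotate(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Apply standard 1D rotary embedding to vector x at integer position pos."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE rotates channel pairs, so the dimension must be even")
    out = [0.0] * d
    for i in range(d // 2):
        # Pair i rotates by angle pos * base^(-2i/d)
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out[2 * i] = x0 * c - x1 * s
        out[2 * i + 1] = x0 * s + x1 * c
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.2, 0.9]
# Same offset (m - n = 3) at two different absolute positions
s1 = dot(rope_rotate(q, 5), rope_rotate(k, 2))
s2 = dot(rope_rotate(q, 103), rope_rotate(k, 100))
print(abs(s1 - s2) < 1e-9)
```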

## Requirements

- Python 3.8+
- PyTorch 1.12+
- torchvision
- wandb (optional, for experiment tracking)
- matplotlib
- numpy
- tqdm
- datasets (Hugging Face)

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/fourier-roformer.git
cd fourier-roformer

# Install dependencies
pip install torch torchvision tqdm matplotlib numpy wandb datasets

# Optional: log in to wandb for experiment tracking (or set WANDB_MODE=offline to disable)
wandb login
```

## Usage

### Basic Training

```bash
# Train FourierRoFormer on CIFAR-10
python train_fourier_roformer.py --dataset cifar10 --batch_size 128 --epochs 200 \
    --d_model 384 --n_heads 6 --n_layers 12 --variants fourier

# Compare with baseline models
python train_fourier_roformer.py --dataset cifar10 --batch_size 128 --epochs 200 \
    --d_model 384 --n_heads 6 --n_layers 12 --variants fourier roformer vit deit
```

### Model Configurations

The paper evaluates three model sizes:

- **Small**: `--d_model 192 --n_heads 6 --n_layers 6`
- **Medium**: `--d_model 384 --n_heads 6 --n_layers 12` (default)
- **Large**: `--d_model 576 --n_heads 12 --n_layers 12`
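Each configuration splits `d_model` evenly across attention heads, and rotary embeddings rotate channels in pairs, so the per-head dimension must be an even integer. A quick sanity check, assuming standard multi-head attention (`head_dim = d_model / n_heads`); `CONFIGS` and `head_dim` are hypothetical helpers, not part of the training script:

```python
from typing import Dict

# The three sizes listed above
CONFIGS: Dict[str, Dict[str, int]] = {
    "small": {"d_model": 192, "n_heads": 6, "n_layers": 6},
    "medium": {"d_model": 384, "n_heads": 6, "n_layers": 12},
    "large": {"d_model": 576, "n_heads": 12, "n_layers": 12},
}


def head_dim(d_model: int, n_heads: int) -> int:
    """Per-head width; must divide evenly and be even for rotary pairs."""
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    hd = d_model // n_heads
    if hd % 2:
        raise ValueError("head_dim must be even for rotary embeddings")
    return hd


for name, cfg in CONFIGS.items():
    print(name, head_dim(cfg["d_model"], cfg["n_heads"]))
# small 32
# medium 64
# large 48
```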

### Datasets

Supported datasets and their recommended settings:

```bash
# CIFAR-10/100 (32×32 images, 4×4 patches)
python train_fourier_roformer.py --dataset cifar10  # or cifar100

# Oxford-Flowers102 (224×224 images, 16×16 patches)
python train_fourier_roformer.py --dataset oxford-flowers --img_size 224

# ImageNet-subset via Tiny-ImageNet (64×64 images resized to 224×224, 16×16 patches)
python train_fourier_roformer.py --dataset imagenet-subset --img_size 224
```
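The image and patch sizes above fix the length of the patch sequence and the range of relative offsets the positional encoding must cover. The sketch below computes both for a square image; `patch_grid` and `patch_coords` are hypothetical helpers, and the counts exclude any extra tokens (such as a class token) the model may prepend.

```python
from typing import List, Tuple


def patch_grid(img_size: int, patch_size: int) -> Tuple[int, int]:
    """Return (patches per side, total patch tokens) for a square image."""
    if img_size % patch_size:
        raise ValueError("img_size must be divisible by patch_size")
    side = img_size // patch_size
    return side, side * side


def patch_coords(side: int) -> List[Tuple[int, int]]:
    """(row, col) grid position of each patch token in row-major order."""
    return [(r, c) for r in range(side) for c in range(side)]


for name, img, patch in [("CIFAR-10/100", 32, 4), ("Flowers/ImageNet-subset", 224, 16)]:
    side, n_tokens = patch_grid(img, patch)
    # The largest relative offset along one axis is side - 1
    print(f"{name}: {side}x{side} grid, {n_tokens} patch tokens, max axis offset {side - 1}")
```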

### Ablation Studies

```bash
# Run comprehensive ablation studies on CIFAR-100
python train_fourier_roformer.py --dataset cifar100 --run_ablation

# For faster experimentation with smaller dataset portions
python train_fourier_roformer.py --dataset cifar100 --run_ablation --subset_size 100
```
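`create_balanced_subset` is listed under Utilities, but its signature and the exact meaning of `--subset_size` are not documented. The sketch below is a hypothetical stand-in that assumes `--subset_size` counts examples per class and samples them reproducibly from a seed; `balanced_subset_indices` is an illustrative name.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def balanced_subset_indices(labels: Sequence[int], per_class: int, seed: int = 0) -> List[int]:
    """Pick up to per_class indices from every class, reproducibly."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    chosen: List[int] = []
    for y in sorted(by_class):
        idxs = by_class[y]
        # Classes with fewer examples than per_class contribute all of them
        chosen.extend(rng.sample(idxs, min(per_class, len(idxs))))
    return sorted(chosen)


labels = [i % 3 for i in range(150)]  # toy labels: 3 classes, 50 examples each
subset = balanced_subset_indices(labels, per_class=10, seed=1)
print(len(subset))
```

The returned indices could then be passed to `torch.utils.data.Subset(dataset, indices)` to build the reduced training set.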

## Reproducing Paper Results

To reproduce the main results from our paper:

```bash
# Results are aggregated over 5 random seeds; significance (p < 0.05) is tested across these runs
# Example: CIFAR-100 (Medium model)
for seed in 42 1337 2022 9999 8888; do
    python train_fourier_roformer.py --dataset cifar100 --batch_size 128 --epochs 200 \
        --d_model 384 --n_heads 6 --n_layers 12 --variants fourier --seed $seed
done
```
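Once each seed's final test accuracy has been collected (for example from the wandb runs), the runs can be summarized as mean ± standard deviation and compared against a baseline. The README does not say which statistical test produced p < 0.05 or whether ± denotes the sample standard deviation, so this sketch assumes the sample standard deviation and Welch's unequal-variance t-test; `mean_std` and `welch_t` are hypothetical helpers.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_std(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation (n - 1 denominator) across seeds."""
    return statistics.mean(values), statistics.stdev(values)


def welch_t(a: Sequence[float], b: Sequence[float]) -> float:
    """Welch's t statistic for two independent groups with unequal variances."""
    va, vb = statistics.variance(a), statistics.variance(b)
    return (statistics.mean(a) - statistics.mean(b)) / math.sqrt(va / len(a) + vb / len(b))
```

For a two-sided p-value, `scipy.stats.ttest_ind(a, b, equal_var=False)` computes the same Welch statistic together with its p-value.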

## Results

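FourierRoFormer consistently outperforms the baseline models by statistically significant margins. Entries are classification accuracy (%), reported as mean ± standard deviation across random seeds: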

| Model | CIFAR-10 | CIFAR-100 | Oxford-Flowers102 | ImageNet-subset |
|-------|----------|-----------|-----------------|-----------------|
| Standard ViT | 93.21±0.14 | 77.79±0.21 | 93.68±0.18 | 47.58±0.23 |
| DeiT | 94.58±0.12 | 79.55±0.18 | 94.75±0.15 | 46.95±0.19 |
| RoFormer | 94.63±0.11 | 78.42±0.19 | 94.23±0.16 | 49.51±0.22 |
| FourierRoFormer | **96.28±0.10** | **84.26±0.15** | **96.04±0.13** | **51.16±0.18** |

*All improvements are statistically significant (p < 0.05)*
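To make the margins explicit, the snippet below recomputes FourierRoFormer's gain over the strongest baseline on each dataset, using only the means from the table (standard deviations are ignored). `margins_over_best_baseline` is an illustrative helper, not part of the repo.

```python
from typing import Dict

# Mean accuracies (%) copied from the table above
RESULTS: Dict[str, Dict[str, float]] = {
    "Standard ViT": {"CIFAR-10": 93.21, "CIFAR-100": 77.79, "Oxford-Flowers102": 93.68, "ImageNet-subset": 47.58},
    "DeiT": {"CIFAR-10": 94.58, "CIFAR-100": 79.55, "Oxford-Flowers102": 94.75, "ImageNet-subset": 46.95},
    "RoFormer": {"CIFAR-10": 94.63, "CIFAR-100": 78.42, "Oxford-Flowers102": 94.23, "ImageNet-subset": 49.51},
    "FourierRoFormer": {"CIFAR-10": 96.28, "CIFAR-100": 84.26, "Oxford-Flowers102": 96.04, "ImageNet-subset": 51.16},
}


def margins_over_best_baseline(results: Dict[str, Dict[str, float]], model: str) -> Dict[str, float]:
    """Per-dataset gain (percentage points) of `model` over the best other model."""
    out = {}
    for ds, acc in results[model].items():
        best = max(scores[ds] for name, scores in results.items() if name != model)
        out[ds] = round(acc - best, 2)
    return out


print(margins_over_best_baseline(RESULTS, "FourierRoFormer"))
```

This gives gains of 1.65, 4.71, 1.29, and 1.65 points on CIFAR-10, CIFAR-100, Oxford-Flowers102, and ImageNet-subset, respectively.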

## Code Structure

- **Attention Mechanisms**
  - `ScaledDotProductAttention`: Standard attention used in ViT and DeiT
  - `RotaryPositionalEncoding`: RoPE implementation for RoFormer
  - `FourierRoFormer`: Our novel attention with learned Fourier components

- **Model Components**
  - `EnhancedMultiHeadAttention`: Supports different attention variants
  - `EnhancedTransformerEncoderLayer`: Transformer layer with configurable attention
  - `EnhancedVisionTransformer`: Main vision transformer implementation

- **Training & Evaluation**
  - `prepare_dataset`: Dataset loading with appropriate preprocessing
  - `train_model`: Training loop with metrics logging
  - `run_experiment`: Comparative evaluation of different variants
  - `run_ablation_study`: Comprehensive ablation studies

- **Utilities**
  - `WandbLogger`: Experiment tracking and visualization
  - `generate_ablation_configs`: Ablation configuration generation
  - `create_balanced_subset`: Dataset subsetting for faster experimentation
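The `--variants` names used in the training commands correspond to the attention mechanisms listed above: `vit` and `deit` use `ScaledDotProductAttention`, `roformer` uses `RotaryPositionalEncoding`, and `fourier` uses `FourierRoFormer`. A hypothetical sketch of that lookup, with early rejection of unknown names; the actual dispatch inside `EnhancedMultiHeadAttention` is not documented and may differ.

```python
from typing import Dict, Iterable

# Hypothetical mapping from --variants names to the attention classes listed above
ATTENTION_FOR_VARIANT: Dict[str, str] = {
    "vit": "ScaledDotProductAttention",
    "deit": "ScaledDotProductAttention",
    "roformer": "RotaryPositionalEncoding",
    "fourier": "FourierRoFormer",
}


def resolve_variants(names: Iterable[str]) -> Dict[str, str]:
    """Map each requested variant to its attention mechanism, rejecting unknown names."""
    names = list(names)  # allow generators without consuming them twice
    unknown = sorted(set(names) - set(ATTENTION_FOR_VARIANT))
    if unknown:
        raise ValueError(f"unknown variant(s): {', '.join(unknown)}")
    return {n: ATTENTION_FOR_VARIANT[n] for n in names}


print(resolve_variants(["fourier", "roformer", "vit", "deit"]))
```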
