# LLM Spectral Clipping

This repository contains an initial implementation of the **SPECTRA** framework for LLM training, built on top of [llm-baselines](https://github.com/epfml/llm-baselines) and [nanoGPT](https://github.com/karpathy/nanogpt).

## Overview

This project investigates **spectral clipping**: a method that bounds the spectral norm of optimizer updates by clipping their singular values.

### Current status

- **Spectral Clipping**: Efficient approximation of soft spectral clipping using Newton-Schulz iteration.
- **Spectral Normalization**: Muon-style normalization where all singular values are mapped to approximately 1.
- **Spectral Wrapper**: Works with any optimizer that follows the update rule described below and uses decoupled weight decay (AdamW, Signum, AdEMAMix, etc.).
- **Spectral Gradient Pre-Clipping**: Per-parameter spectral clipping of raw gradients before the optimizer step.
- **Singular Value (SV) Recording**: Tools for analyzing singular value distributions of gradients and updates.
- **Gradient Noise Structure Analysis**: Top-k SVD for gradient noise with subspace distance metrics (spectral and chordal).
- **Update Noise Structure Analysis**: Analyzes the noise structure of optimizer updates using a temporary optimizer (avoids corrupting the fused AdamW state).
- **Sum of L∞-Norms Logging**: Track `Σ_p ||p||_∞` for model parameters.

## Spectral Post-Processing

### Mathematical Background

**Soft Spectral Clipping** uses the following smooth approximation of hard clipping:
```
hc(X) = (I + XX^T / c^2)^{-1/2} @ X
```

This maps singular values: `sigma -> sigma / sqrt(1 + sigma^2/c^2)`, which:
- Preserves small singular values (sigma << c): approximately unchanged
- Clips large singular values (sigma >> c): approaches c
- Smooth transition (no discontinuity at threshold)

The matrix inverse square root is computed efficiently using Newton-Schulz iteration, avoiding expensive full SVD.
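A minimal NumPy sketch of soft clipping without an SVD is shown below. It assumes the coupled Newton-Schulz iteration for the matrix inverse square root, with the input scaled by its Frobenius norm so the iteration converges; the repository's exact iteration, scaling, and dtype may differ, and `soft_spectral_clip` / `soft_spectral_clip_svd` are hypothetical names. The SVD version is included only as an exact reference for checking.

```python
import numpy as np


def soft_spectral_clip(X, c, ns_steps=10):
    """Approximate hc(X) = (I + X X^T / c^2)^{-1/2} @ X (hypothetical helper).

    Uses the coupled Newton-Schulz iteration: Y_k -> (A/s)^{1/2}, Z_k -> (A/s)^{-1/2}.
    """
    m = X.shape[0]
    eye = np.eye(m)
    A = eye + (X @ X.T) / c**2        # symmetric positive definite, eigenvalues >= 1
    s = np.linalg.norm(A, "fro")      # s >= lambda_max(A), so eig(A / s) lies in (0, 1]
    Y = A / s
    Z = eye.copy()
    for _ in range(ns_steps):
        T = 0.5 * (3.0 * eye - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    # (A / s)^{-1/2} = sqrt(s) * A^{-1/2}, hence A^{-1/2} = Z / sqrt(s)
    return (Z / np.sqrt(s)) @ X


def soft_spectral_clip_svd(X, c):
    """Exact reference via SVD: sigma -> sigma / sqrt(1 + sigma^2 / c^2)."""
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    return (U * (S / np.sqrt(1.0 + S**2 / c**2))) @ Vt


rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
approx = soft_spectral_clip(X, c=1.0, ns_steps=20)
print(np.max(np.abs(approx - soft_spectral_clip_svd(X, c=1.0))))  # approximation error
print(np.linalg.svd(approx, compute_uv=False))                     # all below c = 1.0
```

With Frobenius-norm scaling, the smallest eigenvalue of `A / s` shrinks as `σ_max / c` grows, so more iterations are needed to converge; this is why the step count (`--spectral_ns_steps`) matters. For tall matrices, the equivalent form `X @ (I + XᵀX / c²)^{-1/2}` works on the smaller Gram matrix.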

### Configuration Options

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--spectral_post_process` | Processing mode: "none", "clip", "normalize" | "none" |
| `--spectral_clip_c` | Clipping threshold c | 10.0 |
| `--spectral_ns_steps` | Number of Newton-Schulz iterations | 10 |
| `--spectral_apply_to` | Apply to "2d" (matrices only) or "all" parameters | "all" |

### Implementation Details

The `SpectralOptimizer` wrapper:
1. Stores the weights `X_k` before the optimizer step
2. Calls the base optimizer step (AdamW, Signum, etc.)
3. Recovers the update direction: `U_k = (X_k - X_{k+1}) / lr - weight_decay * X_k`
4. Applies spectral post-processing to `U_k`
5. Recomputes the weights with the processed update

This approach works with any optimizer whose update rule has the form:

```
X_{k+1} = X_k - lr * (weight_decay * X_k + U_k)
```

where `U_k` is the optimizer-specific update direction (e.g., `sign(momentum)` for Signum).

After computing the post-processed `U_k`, we do the following update:
```
X_{k+1} = (1 - weight_decay * lr) * X_k - lr * α * U_k_processed
```
where `α = max(sqrt(m/n), 1)` and `m × n` is the shape of the parameter matrix (`m` rows, `n` columns).
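The five wrapper steps can be sketched in NumPy as follows. This is an illustration under assumptions, not the repository's `SpectralOptimizer`: the Signum-like base step, the exact-SVD normalization (used here in place of the Newton-Schulz approximation), and the names `signum_like_step`, `normalize_update`, and `spectral_wrapper_step` are all hypothetical.

```python
import numpy as np


def signum_like_step(X, G, lr, weight_decay):
    """Toy base optimizer (hypothetical): X_{k+1} = X_k - lr * (weight_decay * X_k + sign(G))."""
    return X - lr * (weight_decay * X + np.sign(G))


def normalize_update(U):
    """Muon-style normalization via exact SVD (stand-in for Newton-Schulz): all singular values -> 1."""
    u, _, vt = np.linalg.svd(U, full_matrices=False)
    return u @ vt


def spectral_wrapper_step(X, base_step, post_process, lr, weight_decay):
    """One wrapped step (hypothetical helper); returns the new weights and the recovered U_k."""
    X_before = X.copy()                                          # 1. store X_k
    X_after = base_step(X_before.copy())                         # 2. base optimizer step
    U = (X_before - X_after) / lr - weight_decay * X_before      # 3. recover U_k
    U_proc = post_process(U)                                     # 4. spectral post-processing
    m, n = X_before.shape
    alpha = max(np.sqrt(m / n), 1.0)                             # shape-dependent scale
    X_new = (1.0 - weight_decay * lr) * X_before - lr * alpha * U_proc  # 5. recompute weights
    return X_new, U


rng = np.random.default_rng(0)
X, G = rng.standard_normal((3, 5)), rng.standard_normal((3, 5))
lr, wd = 0.01, 0.1
X_new, U = spectral_wrapper_step(
    X, lambda W: signum_like_step(W, G, lr, wd), normalize_update, lr, wd
)
```

With the identity as `post_process` and `m <= n` (so `α = 1`), the wrapper reproduces the base optimizer step exactly, which is a convenient sanity check for the recovery formula in step 3.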


### Dynamic Clipping During Warmup

During learning-rate warmup, the clipping threshold is adjusted so that `c * lr` stays constant:
- During warmup (step < `warmup_steps`): `c = (clip_c * base_lr) / current_lr`
- After warmup: `c = clip_c` (fixed)

Dynamic clipping during warmup is enabled by default. To disable it and keep the clipping threshold constant, pass `--disable_dynamic_clip`.
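A minimal sketch of this schedule follows. It assumes the current step index is compared against `warmup_steps` and that a zero learning rate (common at the first warmup step) should disable clipping; the helper name `dynamic_clip_threshold` and its signature are hypothetical.

```python
import math


def dynamic_clip_threshold(step, clip_c, base_lr, current_lr, warmup_steps, dynamic=True):
    """Clipping threshold c for a given step (hypothetical helper).

    During warmup, c is scaled so that c * current_lr == clip_c * base_lr.
    """
    if dynamic and step < warmup_steps:
        if current_lr <= 0.0:
            return math.inf  # lr = 0 would require an unbounded threshold, i.e. no clipping
        return clip_c * base_lr / current_lr
    return clip_c


# Linear warmup example: the product c * lr stays at clip_c * base_lr throughout warmup.
base_lr, clip_c, warmup_steps = 1e-3, 10.0, 100
for step in (0, 49, 99, 100):
    lr = base_lr * min(step + 1, warmup_steps) / warmup_steps
    print(step, dynamic_clip_threshold(step, clip_c, base_lr, lr, warmup_steps))
```

Since `current_lr < base_lr` during warmup, the threshold starts above `clip_c` and decreases to it, so the singular values of `U_k` are clipped more loosely early on while the product `c * lr` is held fixed.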

## Spectral Gradient Pre-Clipping

In addition to post-processing optimizer updates, we support **per-parameter spectral clipping of raw gradients** as a pre-processing step.

### Difference from Global Gradient Clipping

| Method | What it bounds | Scope |
|--------|----------------|-------|
| `--grad_clip` (existing) | Global L2 norm `\|\|g\|\|_2` | All parameters together |
| `--spectral_grad_clip` (new) | Spectral norm `\|\|G\|\|_2 = σ_max(G)` | Per-parameter (2d only) |

### Configuration Options

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--spectral_grad_clip` | Pre-clipping mode: "none" or "clip" | "none" |
| `--spectral_grad_clip_c` | Clipping threshold | 0.1 |

The options `--spectral_ns_steps` and `--disable_dynamic_clip` are shared by spectral pre-clipping and spectral post-processing.

### Order of Operations

1. Backward pass
2. **Spectral gradient clipping** (per-parameter)
3. Global L2 gradient clipping (`--grad_clip`)
4. Optimizer step
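Steps 2 and 3 can be sketched in NumPy as below. Two assumptions: the `clip` mode is taken to be the same soft spectral clipping used for post-processing (computed here with an exact SVD instead of Newton-Schulz), and global clipping rescales all gradients by `grad_clip / ||g||_2` when the total norm exceeds `grad_clip`, similar in spirit to PyTorch's `clip_grad_norm_`. The helper names are hypothetical.

```python
import numpy as np


def soft_clip_svd(G, c):
    """Soft spectral clipping via exact SVD: sigma -> sigma / sqrt(1 + sigma^2 / c^2)."""
    U, S, Vt = np.linalg.svd(G, full_matrices=False)
    return (U * (S / np.sqrt(1.0 + S**2 / c**2))) @ Vt


def preprocess_grads(grads, spectral_grad_clip_c=None, grad_clip=0.0):
    """Steps 2-3 for a dict {name: gradient} (hypothetical helper); input dict is not modified."""
    out = dict(grads)
    # 2. Per-parameter spectral clipping, applied to 2-D parameters only
    if spectral_grad_clip_c is not None:
        for name, g in out.items():
            if g.ndim == 2:
                out[name] = soft_clip_svd(g, spectral_grad_clip_c)
    # 3. Global L2 clipping: one scale factor shared by all parameters (0.0 disables it)
    if grad_clip > 0.0:
        total = np.sqrt(sum(np.sum(g**2) for g in out.values()))
        if total > grad_clip:
            scale = grad_clip / total
            out = {name: g * scale for name, g in out.items()}
    return out


rng = np.random.default_rng(0)
grads = {"w": 5.0 * rng.standard_normal((4, 6)), "b": rng.standard_normal(6)}
clipped = preprocess_grads(grads, spectral_grad_clip_c=1.0, grad_clip=0.5)
```

Because spectral clipping runs first in this order, the global L2 norm is measured on the already spectrally clipped gradients.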

### Usage

```bash
# Only spectral gradient clipping
python ./src/main.py --opt adamw --spectral_grad_clip clip --spectral_grad_clip_c 1.0 --grad_clip 0.0

# Both: spectral first, then global L2
python ./src/main.py --opt adamw --spectral_grad_clip clip --spectral_grad_clip_c 1.0 --grad_clip 0.5

# Spectral grad clipping + Spectral update post-processing + Global L2
python ./src/main.py --opt adamw \
    --grad_clip 0.5 \
    --spectral_grad_clip clip --spectral_grad_clip_c 1.0 \
    --spectral_post_process clip --spectral_clip_c 10.0
```

## SVD Recording & Analysis

### 1: Recording Singular Values of Gradients and Updates

We record two quantities at specified training steps:
- **Gradient singular values** (`grad_sv`): All singular values of the raw gradients after the backward pass
- **Update singular values** (`update_sv`): Full SVs of the optimizer update `U_k` (excluding weight decay)


**Layers recorded** (for a 12-layer LLaMA):
- `embedding`: Token embedding weights
- `early_attn/mlp`: Layer 0 attention and MLP weights
- `middle_attn/mlp`: Layer 6 attention and MLP weights
- `late_attn/mlp`: Layer 11 attention and MLP weights
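A sketch of what is recorded per layer, assuming the middle layer is `n_layer // 2` (which gives layer 6 for 12 layers) and that `U_k` has already been recovered without the weight-decay term, as in the `SpectralOptimizer` description. The helper names and the record layout are hypothetical.

```python
import numpy as np


def recorded_layer_indices(n_layer):
    """Early / middle / late block indices (assumed rule); gives 0, 6, 11 for n_layer = 12."""
    return {"early": 0, "middle": n_layer // 2, "late": n_layer - 1}


def singular_values(mat):
    """All singular values of a 2-D array, in descending order."""
    return np.linalg.svd(mat, compute_uv=False)


def record_step(grads, updates):
    """Hypothetical record for one step: {name: {"grad_sv": ..., "update_sv": ...}}."""
    return {
        name: {
            "grad_sv": singular_values(grads[name]),
            "update_sv": singular_values(updates[name]),
        }
        for name in grads
    }


print(recorded_layer_indices(12))
```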

#### **Usage:**
See `scripts/records/adamw_singvals.sh` for an example.

### 2: Recording Noise Structure

We analyze the **spectral structure of gradient noise**: the difference between mini-batch gradients and the "true" gradient computed with a large batch.

- Let `G` be the "true" gradient (computed with large batch size, e.g., 4096)
- Let `g` be a stochastic gradient (computed with small batch size, e.g., 1)
- The **noise** is `N = g - G`

**Recorded quantities:**
1. **Top-k singular values of the noise**: the k largest singular values of `N`
2. **Subspace distance**: how far the top-k singular subspace of the noise is from the top-k singular subspace of `G`

**Subspace Distance Metrics:**

For orthonormal matrices `U₁, U₂ ∈ ℝ^{m×k}` (top-k left singular vectors), we compute:
```
A = U₁ᵀ @ U₂        (k×k matrix)
B = I - A @ Aᵀ      (k×k matrix)
```

| Metric | Formula | Interpretation |
|--------|---------|----------------|
| **Spectral Distance** | `sqrt(λ_max(B))` | max\|sin(θᵢ)\| — worst-case misalignment |
| **Chordal Distance** | `sqrt(trace(B)/k)` | RMS of sin(θᵢ) — average-case misalignment |

where `θᵢ` are the principal angles between the subspaces.

- Both metrics are in `[0, 1]`: 0 = identical subspaces, 1 = orthogonal subspaces
- We compute distances for both left and right singular subspaces, and take the maximum
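A NumPy sketch of these quantities is shown below, assuming `k` does not exceed the smaller matrix dimension. It uses a full SVD and truncates it; the repository may use a truncated or randomized top-k SVD instead, and the function names are hypothetical.

```python
import numpy as np


def top_k_svd(M, k):
    """Top-k left singular vectors, singular values, and right singular vectors (as columns)."""
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return U[:, :k], S[:k], Vt[:k].T


def subspace_distances(Q1, Q2):
    """Spectral and chordal distance between the spans of orthonormal Q1, Q2 (both m x k)."""
    k = Q1.shape[1]
    A = Q1.T @ Q2
    B = np.eye(k) - A @ A.T
    sin2 = np.clip(np.linalg.eigvalsh(B), 0.0, 1.0)  # eigenvalues of B are sin^2(theta_i)
    return np.sqrt(sin2.max()), np.sqrt(sin2.sum() / k)


def noise_structure(g, G, k):
    """Top-k SVs of N = g - G and max(left, right) distances to G's top-k subspaces."""
    N = g - G
    Un, Sn, Vn = top_k_svd(N, k)
    Ug, _, Vg = top_k_svd(G, k)
    spec_l, chord_l = subspace_distances(Un, Ug)
    spec_r, chord_r = subspace_distances(Vn, Vg)
    return Sn, max(spec_l, spec_r), max(chord_l, chord_r)


rng = np.random.default_rng(0)
G = rng.standard_normal((8, 6))                 # stand-in for the large-batch gradient
g = G + 0.5 * rng.standard_normal((8, 6))       # stand-in for a small-batch gradient
print(noise_structure(g, G, k=3))
```

Clipping the eigenvalues of `B` to `[0, 1]` only removes floating-point rounding; in exact arithmetic they already lie in that range.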

#### **Usage:**
See `scripts/records/adamw_noise.sh` for an example.

### 3: Recording Update Noise Structure

We extend the noise structure recording to analyze optimizer updates instead of just raw gradients.

- Let `U` be the "true" optimizer update (computed with large-batch gradient + optimizer step)
- Let `u` be a stochastic update (computed with small-batch gradient + optimizer step)
- The **update noise** is `N_u = u - U`


**Implementation:** Uses a **temporary non-fused optimizer** to avoid corrupting fused AdamW's internal CUDA kernel state. The real training optimizer is never modified during recording.
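The idea of computing both updates from the same optimizer state without modifying it can be sketched with a toy NumPy Adam. In the repository this is done with a temporary non-fused PyTorch optimizer; here `copy.deepcopy` of a hypothetical `ToyAdam` class stands in for that, and the class, method, and function names are assumptions.

```python
import copy

import numpy as np


class ToyAdam:
    """Minimal Adam without weight decay (hypothetical); update direction is m_hat / (sqrt(v_hat) + eps)."""

    def __init__(self, shape, beta1=0.9, beta2=0.999, eps=1e-8):
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.m = np.zeros(shape)
        self.v = np.zeros(shape)
        self.t = 0

    def update(self, g):
        """Advance the state with gradient g and return the update direction."""
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * g
        self.v = self.beta2 * self.v + (1 - self.beta2) * g**2
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        return m_hat / (np.sqrt(v_hat) + self.eps)


def update_noise(opt, g_small, g_large):
    """N_u = u - U, each computed on a throwaway copy so that `opt` is never modified."""
    u = copy.deepcopy(opt).update(g_small)
    U = copy.deepcopy(opt).update(g_large)
    return u - U
```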

#### **Usage:**
See `scripts/records/adamw_noise_update.sh` for an example.

### Generating Plots

See `scripts/analysis` for details.

## Parameter Norm Logging

When `--log_parameter_norms` is enabled, the following norm metrics are logged to WandB.

### Metrics Logged

| Metric | Description |
|--------|-------------|
| `model_norm/total` | L2 norm of all parameters |
| `model_norm/matrices` | L2 norm of 2D parameters only |
| `model_norm/sum_linf_total` | Sum of L∞-norms: `Σ_p max_i \|p_i\|` |
| `model_norm/sum_linf_matrices` | Sum of L∞-norms of 2D parameters only |

**Note:** The sum of L∞-norms (`Σ_p ||p||_∞`) is different from `max_p ||p||_∞`. It sums the maximum absolute value of each parameter tensor.
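For illustration, the four metrics can be computed from a dictionary of NumPy arrays as below. This is a sketch rather than the repository's logging code: `parameter_norms` is a hypothetical name, and "L2 norm of all parameters" is assumed to mean the norm of all entries taken together. The example makes the note concrete: with tensors whose L∞-norms are 4 and 2, the sum is 6, whereas the maximum would be 4.

```python
import numpy as np


def parameter_norms(params):
    """Norm metrics over a dict {name: array}; keys mirror the WandB metric names above."""
    all_p = list(params.values())
    mats = [p for p in all_p if p.ndim == 2]
    return {
        "model_norm/total": np.sqrt(sum(np.sum(p**2) for p in all_p)),
        "model_norm/matrices": np.sqrt(sum(np.sum(p**2) for p in mats)),
        "model_norm/sum_linf_total": sum(np.max(np.abs(p)) for p in all_p),
        "model_norm/sum_linf_matrices": sum(np.max(np.abs(p)) for p in mats),
    }


params = {"w": np.array([[3.0, -4.0], [0.0, 0.0]]), "b": np.array([0.0, -2.0])}
print(parameter_norms(params)["model_norm/sum_linf_total"])  # 6.0 (= 4 + 2); max_p ||p||_inf is 4
```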

**Usage:**
```bash
python ./src/main.py --log_parameter_norms ...
```

## Shared Memory Data Loading

For large datasets (e.g., FineWeb-Edu at 187GB) on multiple GPUs, use `--shared_memory` to avoid memory duplication:

| Option | Memory Usage (4 GPUs, 187GB dataset) | Notes |
|--------|--------------------------------------|-------|
| `--data_in_ram` | 4 × 187GB = 748GB | OOM |
| Neither | Minimal | Slow (disk I/O) |
| `--shared_memory` | 1 × 187GB = 187GB | |

**Details:** Rank 0 loads the dataset into POSIX shared memory once, and all GPU workers attach to the same memory region (zero-copy).
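A minimal sketch of this zero-copy pattern using Python's standard `multiprocessing.shared_memory` module with a NumPy view. Whether the repository uses this module or another POSIX shared-memory API is an assumption, as are the helper names. On Linux, a block named `train_shm` appears as `/dev/shm/train_shm`, which is why stale blocks can be removed with `rm` after a killed job.

```python
from multiprocessing import shared_memory

import numpy as np


def create_shared_array(name, data):
    """Rank 0 (hypothetical helper): copy `data` once into a new named shared-memory block."""
    shm = shared_memory.SharedMemory(name=name, create=True, size=data.nbytes)
    view = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
    view[:] = data
    del view  # release the buffer export so shm.close() can succeed later
    return shm


def attach_shared_array(name, shape, dtype):
    """Other ranks (hypothetical helper): map the existing block; the array shares its memory."""
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)


tokens = np.arange(1000, dtype=np.uint16)
owner = create_shared_array("demo_shm", tokens)
shm, arr = attach_shared_array("demo_shm", tokens.shape, tokens.dtype)
print(arr[:5])
del arr          # drop NumPy views before closing the mapping
shm.close()
owner.close()
owner.unlink()   # removes /dev/shm/demo_shm on Linux
```

One caveat of this sketch: before Python 3.13, a process that merely attaches to a block also registers it with its resource tracker, which may warn about or unlink the block when that process exits; Python 3.13 adds a `track=False` option for attaching processes.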

**Usage:**
```bash
torchrun --nproc_per_node=4 ./src/main.py \
    --dataset finewebedu \
    --shared_memory \
    --spectral_post_process clip \
    ...
```

**Manual cleanup** (if the job is killed):
```bash
rm /dev/shm/train_shm /dev/shm/val_shm
```



