# Softmax Lipschitz — Empirical Studies 

This repository contains scripts used to empirically validate Lipschitz properties of the **softmax** operator across **vision**, **language**, and **reinforcement learning** settings.

The code is organized into the following experiments:

- **Vision (classification logits):** `softmax_lipschitz_logits.py`
- **Vision (ViT attention scores):** `vit_attention_stability.py` (and a variant `vit_attention_stability_imagenet_changed.py`)
- **Language models (GPT‑2 attention scores):** `gpt2_attention_stability.py`
- **Language models (Qwen attention scores via Q/K hooks):** `qwen_attention_stability.py`
- **Reinforcement learning (stochastic policies):**
  - `softmax_lipschitz_cartpole.py`
  - `softmax_lipschitz_lunarlander.py`

> All experiments compute empirical ratios of the form  
> \[ r = \frac{\|\operatorname{softmax}(z + \Delta) - \operatorname{softmax}(z)\|_p}{\|\Delta\|_p} \]  
> (or its attention analogue, with the softmax applied row-wise to pre‑softmax scores \(S\)), sweeping over \(p\), \( \epsilon=\|\Delta\|_p \), trials, and layers/heads, and taking the maximum ratio as described in each script.
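
As a minimal sketch of the ratio above (NumPy only, not the repository's code), the snippet below draws a random perturbation, rescales it so that \(\|\Delta\|_p = \epsilon\) exactly, and keeps the largest ratio over a few trials. The names `softmax`, `lipschitz_ratio`, and `max_ratio` are hypothetical, and the scripts may sample \(\Delta\) differently.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=-1, keepdims=True)

def lipschitz_ratio(z: np.ndarray, delta: np.ndarray, p: float) -> float:
    """r = ||softmax(z + delta) - softmax(z)||_p / ||delta||_p for a 1-D logit vector."""
    num = np.linalg.norm(softmax(z + delta) - softmax(z), ord=p)
    den = np.linalg.norm(delta, ord=p)
    return float(num / den)

def max_ratio(z: np.ndarray, p: float, eps: float, num_trials: int,
              rng: np.random.Generator) -> float:
    """Largest ratio over random perturbations rescaled so that ||delta||_p == eps."""
    best = 0.0
    for _ in range(num_trials):
        delta = rng.standard_normal(z.shape)
        delta *= eps / np.linalg.norm(delta, ord=p)  # enforce the target norm
        best = max(best, lipschitz_ratio(z, delta, p))
    return best

rng = np.random.default_rng(0)
z = rng.standard_normal(10)  # random logits standing in for a model output
for p in (1, 2, np.inf):
    for eps in (1e-3, 1e-1, 10.0):
        print(f"p={p}, eps={eps}: max r = {max_ratio(z, p, eps, 3, rng):.4f}")
```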

---

## 1) Environment & Dependencies

### Recommended
- Python 3.10–3.12
- CUDA 11.8+ (for GPU acceleration)

### Core packages
```bash
pip install -U pip setuptools wheel

# Common scientific stack
pip install numpy pandas matplotlib

# Torch & vision (choose builds matching your CUDA)
pip install torch torchvision 

# Hugging Face (for LM experiments)
pip install transformers datasets accelerate

# RL stack
pip install gymnasium stable-baselines3
pip install swig                  # build tool for Box2D (LunarLander)
pip install "gymnasium[box2d]"    # LunarLander environments

# Optional: 4-bit/8-bit quantization for large LMs (Qwen)
pip install bitsandbytes
```

## 2) Datasets

- **Vision:** CIFAR‑10/100 (auto‑downloaded), ImageFolder (set `--dataset_path` to the image root, e.g., ImageNet `val/`), or a synthetic `Random` option used for smoke tests.
- **Language:** Lightweight commonsense datasets via 🤗 Datasets (e.g., `hellaswag` or a PIQA mirror such as `regisss/piqa`).
- **RL:** Gymnasium environments (`CartPole-v1`, `LunarLander-v2`/`LunarLander-v3`).

---

## 3) Scripts & Usage

### A) Vision — Classification logits
**File:** `softmax_lipschitz_logits.py`  
**Goal:** Perturb classification **logits** and compute the empirical Lipschitz ratio of the softmax over classes.

**Typical run:**
```bash
python softmax_lipschitz_logits.py \
  --dataset_choice CIFAR100 \
  --dataset_path ./data \
  --image_size 224 \
  --num_images 25000 \
  --batch_size 128 \
  --p_list 1 2 5 10 inf \
  --eps_list 1e-3 1e-2 1e-1 1 10 \
  --num_trials 3 \
  --out_png ./images/empirical_Lp_softmax_logits.png
```
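
Under assumptions, the logits experiment can be approximated by a batched version of the same ratio: one perturbation per row of an `[N, C]` logit matrix, each rescaled to \(\|\Delta_i\|_p = \epsilon\), with the maximum taken over rows and trials for every \((p, \epsilon)\) pair. Random Gaussian logits stand in for model outputs, and `sweep` is a hypothetical helper; it does not reproduce the script's CLI or plotting.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    # Row-wise, numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sweep(logits, p_list, eps_list, num_trials, seed=0):
    """Return {(p, eps): max ratio over all rows and trials} for an [N, C] logit batch."""
    rng = np.random.default_rng(seed)
    base = softmax(logits)
    results = {}
    for p in p_list:
        for eps in eps_list:
            best = 0.0
            for _ in range(num_trials):
                delta = rng.standard_normal(logits.shape)
                # Rescale every row so that ||delta_i||_p == eps exactly.
                delta *= eps / np.linalg.norm(delta, ord=p, axis=1, keepdims=True)
                diff = np.linalg.norm(softmax(logits + delta) - base, ord=p, axis=1)
                best = max(best, float(np.max(diff / eps)))
            results[(p, eps)] = best
    return results

# Stand-in for model logits: 128 "images" x 100 classes (CIFAR-100-sized), random values.
logits = np.random.default_rng(42).standard_normal((128, 100)) * 3.0
table = sweep(logits, p_list=[1, 2, 5, 10, np.inf],
              eps_list=[1e-3, 1e-2, 1e-1, 1, 10], num_trials=3)
for (p, eps), r in sorted(table.items(), key=lambda kv: (float(kv[0][0]), kv[0][1])):
    print(f"p={p:>4}  eps={eps:<6g}  max r={r:.4f}")
```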

### B) Vision — ViT attention scores
**File:** `vit_attention_stability.py`  
**Goal:** Hook ViT encoder blocks, compute the pre‑softmax **attention scores** \(S\), perturb their rows, and measure the empirical Lipschitz ratio of the softmax taken over the last dimension.

**Typical run (ImageFolder/ImageNet val):**
```bash
python vit_attention_stability.py \
  --model_name vit_b_16 \
  --use_pretrained \
  --dataset_choice ImageFolder \
  --dataset_path /path/to/imagenet/val \
  --image_size 224 \
  --num_images 100 \
  --batch_size 16 \
  --p_list 1 2 10 25 inf \
  --epsilons 1e-3 1e-2 1e-1 1 10 100 \
  --max_layers 4 \
  --max_heads 0 \
  --out_png ./images/empirical_Lp_ViT.png
```
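
For the ViT scores, a hedged sketch of the row-wise version: each row \(S_{b,h,i,:}\) of a `[B, H, N, N]` score tensor gets its own perturbation, the softmax is taken over keys (the last dimension), and the ratio is maximized per head. Synthetic \(Q\) and \(K\) stand in for hooked activations, sized like `vit_b_16` at 224 px (197 tokens, 12 heads, \(d_h = 64\)); `rowwise_max_ratio` is a hypothetical name.

```python
import numpy as np

def softmax_last(x: np.ndarray) -> np.ndarray:
    # Softmax over the last dimension (keys), as in attention.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def rowwise_max_ratio(S, p, eps, num_trials, rng):
    """Per-head max of ||softmax(S+D)_i - softmax(S)_i||_p / ||D_i||_p.

    S: pre-softmax scores [B, H, N, N]; every row D_i is rescaled to ||D_i||_p = eps.
    The maximum is taken over batch, query rows, and trials, leaving one value per head.
    """
    A = softmax_last(S)
    best = np.zeros(S.shape[1])
    for _ in range(num_trials):
        D = rng.standard_normal(S.shape)
        D *= eps / np.linalg.norm(D, ord=p, axis=-1, keepdims=True)
        ratios = np.linalg.norm(softmax_last(S + D) - A, ord=p, axis=-1) / eps  # [B, H, N]
        best = np.maximum(best, ratios.max(axis=(0, 2)))
    return best

# Synthetic stand-in for one encoder block: 196 patches + CLS = 197 tokens, 12 heads, d_h = 64.
rng = np.random.default_rng(0)
B, H, N, d_h = 2, 12, 197, 64
Q = rng.standard_normal((B, H, N, d_h))
K = rng.standard_normal((B, H, N, d_h))
S = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_h)
per_head = rowwise_max_ratio(S, p=2, eps=1e-1, num_trials=3, rng=rng)
print(per_head.round(4))
```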

### C) Language — GPT‑2 attention scores
**File:** `gpt2_attention_stability.py`  
**Goal:** Register hooks on GPT‑2's fused attention projection (`c_attn`), build the pre‑softmax scores \(S\), perturb them, and aggregate the maximum ratio over prompts, layers, and heads.

**Typical run:**
```bash
python gpt2_attention_stability.py \
  --model gpt2 \
  --dataset hellaswag \
  --split train \
  --num_prompts 100 \
  --p_list 1 2 5 10 inf \
  --eps_list 5e-3 1e-2 5e-2 1e-1 5e-1 1 5 10 \
  --num_trials 5 \
  --max_length 256 \
  --out_png ./images/gpt2_attention_lipschitz.png
```
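
In the Hugging Face GPT-2 implementation, `c_attn` emits query, key, and value concatenated along the last axis (`[B, T, 3 * n_embd]`), so a hook on it has to split and reshape before forming \(S\). This sketch assumes such an output has already been captured (for instance with PyTorch's `register_forward_hook`) and converted to NumPy; it applies GPT-2's causal mask and perturbs only the visible entries of each row. The helper names are hypothetical, and whether the script treats masked entries this way is an assumption.

```python
import numpy as np

def split_c_attn(qkv: np.ndarray, n_head: int):
    """Split a c_attn output [B, T, 3*d] into per-head Q, K of shape [B, n_head, T, d_h]."""
    B, T, three_d = qkv.shape
    d_h = (three_d // 3) // n_head
    q, k, _v = np.split(qkv, 3, axis=-1)  # GPT-2 order: query, key, value

    def to_heads(x):
        return x.reshape(B, T, n_head, d_h).transpose(0, 2, 1, 3)

    return to_heads(q), to_heads(k)

def causal_scores(q, k):
    """Pre-softmax S = Q K^T / sqrt(d_h), with future positions set to -inf."""
    T, d_h = q.shape[-2], q.shape[-1]
    S = q @ k.swapaxes(-1, -2) / np.sqrt(d_h)
    mask = np.tril(np.ones((T, T), dtype=bool))  # True where key j <= query i
    return np.where(mask, S, -np.inf), mask

def softmax_last(x):
    x = x - x.max(axis=-1, keepdims=True)  # diagonal is always finite, so no NaNs
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_max_ratio(S, mask, p, eps, rng):
    """Max row ratio when only the visible (unmasked) entries of each row are perturbed."""
    D = np.where(mask, rng.standard_normal(S.shape), 0.0)
    D *= eps / np.linalg.norm(D, ord=p, axis=-1, keepdims=True)
    diff = softmax_last(S + D) - softmax_last(S)
    return float((np.linalg.norm(diff, ord=p, axis=-1) / eps).max())

# Synthetic stand-in for a captured c_attn output (GPT-2 small: n_embd = 768, 12 heads).
rng = np.random.default_rng(0)
qkv = rng.standard_normal((1, 16, 3 * 768)) * 0.1
q, k = split_c_attn(qkv, n_head=12)
S, mask = causal_scores(q, k)
print(causal_max_ratio(S, mask, p=2, eps=1e-1, rng=rng))
```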

### D) Language — Qwen (Q/K hooks, general LMs)
**File:** `qwen_attention_stability.py`  
**Goal:** Capture the **Q**/**K** projections of modern LMs via hooks, compute attention scores \(S = \frac{QK^\top}{\sqrt{d_h}}\) (with \(d_h\) the per-head dimension), perturb them, and measure the empirical Lipschitz ratio after softmax.

**Typical run:**
```bash
python qwen_attention_stability.py \
  --model_id Qwen/Qwen3-8B \
  --dataset piqa \
  --split train \
  --num_prompts 100 \
  --batch_prompts 8 \
  --p_list 1 2 5 10 inf \
  --eps_list 5e-3 1e-2 5e-2 1e-1 5e-1 1 5 10 \
  --num_trials 3 \
  --max_length 128 \
  --out_png ./images/qwen_attention_lipschitz.png
```
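
For models with grouped-query attention, the key projection (named `k_proj` in many Hugging Face models) has fewer heads than the query projection, so each K head must be shared by a group of query heads before \(S = QK^\top/\sqrt{d_h}\) can be formed. A minimal sketch of that bookkeeping, assuming hooked `[B, T, H * d_h]` projection outputs: it deliberately omits rotary position embeddings (which models such as Qwen apply to Q and K after projection) and any Q/K normalization, and the head counts below are hypothetical rather than Qwen's actual configuration.

```python
import numpy as np

def expand_kv_heads(k: np.ndarray, n_q_heads: int) -> np.ndarray:
    """Repeat each K head so grouped-query K [B, H_kv, T, d_h] lines up with Q [B, H_q, T, d_h]."""
    n_kv = k.shape[1]
    assert n_q_heads % n_kv == 0, "query heads must be a multiple of KV heads"
    # Consecutive query heads share one KV head: [k0, k0, ..., k1, k1, ...].
    return np.repeat(k, n_q_heads // n_kv, axis=1)

def scores_from_projections(q_proj_out, k_proj_out, n_q_heads, n_kv_heads):
    """Build S = Q K^T / sqrt(d_h) from hooked projection outputs of shape [B, T, H*d_h]."""
    B, T, _ = q_proj_out.shape
    d_h = q_proj_out.shape[-1] // n_q_heads
    q = q_proj_out.reshape(B, T, n_q_heads, d_h).transpose(0, 2, 1, 3)
    k = k_proj_out.reshape(B, T, n_kv_heads, d_h).transpose(0, 2, 1, 3)
    k = expand_kv_heads(k, n_q_heads)
    return q @ k.swapaxes(-1, -2) / np.sqrt(d_h)

# Hypothetical small config: 8 query heads sharing 2 KV heads, d_h = 16, 5 tokens.
rng = np.random.default_rng(0)
q_out = rng.standard_normal((1, 5, 8 * 16))
k_out = rng.standard_normal((1, 5, 2 * 16))
S = scores_from_projections(q_out, k_out, n_q_heads=8, n_kv_heads=2)
print(S.shape)  # (1, 8, 5, 5)
```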

### E) Reinforcement Learning — CartPole
**File:** `softmax_lipschitz_cartpole.py`  
**Goal:** Train or load a PPO policy, collect states, perturb the **action logits** \(Q(s,\cdot)\), and compute the empirical Lipschitz ratio of the softmax policy \(\pi_\lambda(a \mid s)\).

**Typical run:**
```bash
python softmax_lipschitz_cartpole.py \
  --env cartpole \
  --num_envs 4 \
  --num_states 256 \
  --batch_size 64 \
  --train_steps 0 \
  --p_list 1 2 8 10 inf \
  --eps_list 1e-2 5e-2 1e-1 5e-1 1 5 10 \
  --num_trials 5 \
  --tau_list 0.5 1.0 2.0 \
  --out_png ./images/cartpole_empirical_Lp.png
```
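
The RL scripts sweep a temperature through `--tau_list`. A hedged sketch, assuming the temperature divides the logits, \(\pi(a \mid s) \propto \exp(Q(s,a)/\tau)\) (the script may instead use an inverse temperature, as the \(\lambda\) subscript could suggest): random values stand in for the action logits of 256 collected CartPole states with 2 actions, and `softmax_policy` / `policy_max_ratio` are hypothetical names.

```python
import numpy as np

def softmax_policy(q_values: np.ndarray, tau: float) -> np.ndarray:
    """pi(a|s) = softmax(Q(s, .) / tau), computed row-wise for a batch of states."""
    z = q_values / tau
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def policy_max_ratio(q_values, tau, p, eps, num_trials, rng):
    """Max over states and trials of ||pi(Q+D) - pi(Q)||_p / ||D||_p, with ||D_s||_p = eps."""
    base = softmax_policy(q_values, tau)
    best = 0.0
    for _ in range(num_trials):
        D = rng.standard_normal(q_values.shape)
        D *= eps / np.linalg.norm(D, ord=p, axis=-1, keepdims=True)
        diff = np.linalg.norm(softmax_policy(q_values + D, tau) - base, ord=p, axis=-1)
        best = max(best, float(diff.max() / eps))
    return best

# Stand-in for collected CartPole states: 256 states x 2 actions of random logits.
rng = np.random.default_rng(0)
q_values = rng.standard_normal((256, 2))
for tau in (0.5, 1.0, 2.0):
    print(tau, policy_max_ratio(q_values, tau, p=2, eps=1e-2, num_trials=5, rng=rng))
```

Because the Jacobian of \(\operatorname{softmax}(Q/\tau)\) carries a factor \(1/\tau\), lower temperatures can yield larger ratios, though the observed maximum also depends on how peaked the policy already is at the sampled states.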

### F) Reinforcement Learning — LunarLander
**File:** `softmax_lipschitz_lunarlander.py`  
**Goal:** Same as CartPole but on LunarLander (4 discrete actions).

**Typical run:**
```bash
python softmax_lipschitz_lunarlander.py \
  --env_version 3 \
  --num_envs 4 \
  --num_states 256 \
  --batch_size 64 \
  --train_steps 0 \
  --p_list 1 2 5 20 inf \
  --eps_list 1e-2 5e-2 1e-1 5e-1 1 5 10 \
  --num_trials 5 \
  --tau_list 0.5 1.0 2.0 \
  --out_png ./images/lunarlander_empirical_Lp.png
```





