# Bures-Wasserstein Dense Associative Memory (BW-DAM)

This repository contains code for reproducing experiments in our paper on Bures-Wasserstein Dense Associative Memory models.

## Overview

Dense Associative Memory (DAM) models store patterns and retrieve them through iterative attractor dynamics. When patterns are Gaussian distributions, the choice of geometry significantly affects retrieval performance:

- **BW-DAM**: Uses the Bures-Wasserstein (W₂) distance, respecting the natural geometry of the space of Gaussian distributions
- **Euclidean DAM**: Treats Gaussian parameters as flat vectors, ignoring the manifold structure

Our theoretical analysis (Theorem 1) establishes conditions under which BW-DAM achieves exact retrieval, and our experiments show that it outperforms Euclidean DAM on both synthetic and real-world data.
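
To make the two geometries concrete, here is a minimal sketch of both distances for a pair of Gaussians. It uses the standard closed form for the squared W₂ distance, `‖m₁ - m₂‖² + tr(Σ₁ + Σ₂ - 2(Σ₁^{1/2} Σ₂ Σ₁^{1/2})^{1/2})`. The function names (`psd_sqrt`, `bw_distance_sq`, `euclidean_distance_sq`) are illustrative and are not taken from the experiment scripts, and the flat-vector distance shown is one plausible reading of "Euclidean DAM".

```python
import numpy as np


def psd_sqrt(a):
    """Symmetric square root of a symmetric positive semidefinite matrix."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def bw_distance_sq(m1, s1, m2, s2):
    """Squared 2-Wasserstein distance between N(m1, s1) and N(m2, s2)."""
    root = psd_sqrt(s1)
    cross = psd_sqrt(root @ s2 @ root)
    bures = np.trace(s1) + np.trace(s2) - 2.0 * np.trace(cross)
    return float(np.sum((m1 - m2) ** 2) + max(bures, 0.0))


def euclidean_distance_sq(m1, s1, m2, s2):
    """Squared distance between flattened (mean, covariance) parameters.

    Assumption: this is how the flat-vector baseline measures distance.
    """
    return float(np.sum((m1 - m2) ** 2) + np.sum((s1 - s2) ** 2))


# Two zero-mean Gaussians that differ only in scale.
m = np.zeros(2)
s_a = np.diag([1.0, 1.0])
s_b = np.diag([4.0, 4.0])
print(bw_distance_sq(m, s_a, m, s_b))         # approximately 2
print(euclidean_distance_sq(m, s_a, m, s_b))  # approximately 18
```

For commuting covariances the Bures term reduces to the squared Frobenius distance between the matrix square roots, which is why the two numbers differ so much when only the scale changes.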

## Repository Structure

```
bw_dam_code/
├── README.md
├── requirements.txt
└── experiments/
    ├── synthetic_retrieval.py    # Synthetic Gaussian retrieval (Fig. X)
    ├── word_retrieval.py         # Word2Gauss embeddings (Fig. X)
    ├── sentence_retrieval.py     # GaussCSE sentence embeddings (Fig. X)
    ├── image_retrieval.py        # VAE image embeddings (Fig. X)
    └── visualization_2d.py       # 2D visualization (Fig. X)
```

## Requirements

### Core Dependencies
```
numpy>=1.20.0
scipy>=1.7.0
matplotlib>=3.4.0
tqdm>=4.60.0
```

### For Word Embeddings (word_retrieval.py)
```bash
# word2gauss must be installed from source
git clone https://github.com/seomoz/word2gauss.git
cd word2gauss && pip install -e .
```

### For Sentence Embeddings (sentence_retrieval.py)
```
torch>=1.9.0
transformers>=4.10.0
pandas>=1.3.0
```

### For Image Embeddings (image_retrieval.py)
```
torch>=1.9.0
torchvision>=0.10.0
```

Install the dependencies listed in `requirements.txt`:
```bash
pip install -r requirements.txt
```

## Reproducing Results

### Figure X: Synthetic Gaussian Retrieval
Demonstrates BW-DAM convergence on synthetic Gaussians sampled from a Wasserstein sphere.

```bash
cd experiments
python synthetic_retrieval.py
```

**Output**: `bwdam_synthetic_convergence.png`

### Figure X: 2D Visualization
Visualizes BW-DAM vs Euclidean DAM trajectories in 2D, showing cases where BW-DAM succeeds but Euclidean DAM fails.

```bash
cd experiments
python visualization_2d.py
```

**Output**: `bwdam_vs_euclidean_2d.png`

### Figure X: Word Embedding Retrieval
Evaluates retrieval accuracy on Gaussian word embeddings (Word2Gauss).

**Note**: This experiment requires training Word2Gauss on the text8 corpus. We recommend running it in Google Colab for easier setup.

```bash
cd experiments
python word_retrieval.py
```

**Output**: `bwdam_words_results.png`

### Figure X: Sentence Embedding Retrieval
Evaluates retrieval accuracy on Gaussian sentence embeddings (GaussCSE).

```bash
cd experiments
python sentence_retrieval.py
```

**Output**: 
- `bwdam_sentences_accuracy.png`
- `bwdam_sentences_qualitative.png`

### Figure X: Image Retrieval
Compares BW-DAM, Euclidean DAM, and Pixel DAM on CIFAR-10 images encoded via a full-covariance VAE.

**Note**: This experiment trains a VAE for each trial and may take several hours on CPU. A GPU is recommended.

```bash
cd experiments
python image_retrieval.py
```

**Output**: `bwdam_images_accuracy.png`

## Key Parameters

### BW-DAM Dynamics

| Parameter | Description | Typical Range |
|-----------|-------------|---------------|
| `beta` (β) | Inverse temperature for softmax attention | 0.01 to 100 |
| `epsilon` | Convergence threshold | 1e-6 to 1e-3 |
| `max_iters` | Maximum iterations | 50 to 200 |
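
As a rough illustration of how these three parameters interact, the sketch below runs a softmax-attention retrieval loop over stored Gaussians. It is not the code from the experiment scripts. It assumes that each step moves the query to one fixed-point update of the softmax-weighted Bures-Wasserstein barycenter, and that `epsilon` is compared against the W₂ distance between successive iterates. The name `retrieve` and both of those choices are assumptions made for the example.

```python
import numpy as np


def psd_sqrt(a):
    """Symmetric square root of a symmetric positive semidefinite matrix."""
    vals, vecs = np.linalg.eigh(a)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def bw_distance_sq(m1, s1, m2, s2):
    """Squared 2-Wasserstein distance between two Gaussians."""
    root = psd_sqrt(s1)
    bures = np.trace(s1 + s2) - 2.0 * np.trace(psd_sqrt(root @ s2 @ root))
    return float(np.sum((m1 - m2) ** 2) + max(bures, 0.0))


def retrieve(query_mean, query_cov, means, covs, beta=1.0, epsilon=1e-6, max_iters=100):
    """Hypothetical retrieval loop: softmax attention plus a barycenter step."""
    m, s = query_mean.copy(), query_cov.copy()
    iters = 0
    for iters in range(1, max_iters + 1):
        # Softmax attention over negative squared distances, scaled by beta.
        d2 = np.array([bw_distance_sq(m, s, mi, si) for mi, si in zip(means, covs)])
        logits = -beta * d2
        w = np.exp(logits - logits.max())  # subtract the max for numerical stability
        w /= w.sum()

        # Assumed update: one fixed-point step of the weighted BW barycenter.
        root = psd_sqrt(s)
        inv_root = np.linalg.inv(root)
        mix = sum(wi * psd_sqrt(root @ si @ root) for wi, si in zip(w, covs))
        new_s = inv_root @ mix @ mix @ inv_root
        new_s = 0.5 * (new_s + new_s.T)  # remove round-off asymmetry
        new_m = sum(wi * mi for wi, mi in zip(w, means))

        step = np.sqrt(bw_distance_sq(m, s, new_m, new_s))
        m, s = new_m, new_s
        if step < epsilon:  # assumed stopping rule
            break
    return m, s, iters


means = [np.array([0.0, 0.0]), np.array([10.0, 0.0]), np.array([0.0, 10.0])]
covs = [np.eye(2), np.diag([2.0, 1.0]), np.diag([1.0, 3.0])]
m, s, iters = retrieve(np.array([0.3, -0.2]), 1.2 * np.eye(2), means, covs, beta=5.0)
print(np.round(m, 4), np.round(s, 4), iters)
```

In this sketch a large `beta` makes the attention weights nearly one-hot, so the query snaps to the closest stored pattern, while a small `beta` blends several patterns. Because the step size is computed in floating point, an `epsilon` far below about 1e-7 may never be reached and the loop would then stop at `max_iters`.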

### Theorem 1 Conditions

For exact retrieval guarantees, the stored Gaussians should satisfy:

1. **Eigenvalue bounds**: All covariances have eigenvalues in `[λ_min, λ_max]`
2. **Separation condition**: The mean separation is at least `√(2d λ_max (log κ + 2))`, where `κ = λ_max/λ_min`
3. **Perturbation**: The query lies within distance `r = √λ_min` (the contractive radius) of the target
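
The quantities in these conditions can be computed directly from a set of stored patterns. The sketch below does so under several assumptions that the list above does not spell out: `d` is taken to be the dimension of the Gaussians, `log` the natural logarithm, and "mean separation" the smallest pairwise Euclidean distance between stored means. The helper name `theorem1_quantities` is hypothetical.

```python
import numpy as np


def theorem1_quantities(means, covs):
    """Compute the eigenvalue bounds, separation threshold, and radius."""
    eigs = np.concatenate([np.linalg.eigvalsh(c) for c in covs])
    lam_min, lam_max = float(eigs.min()), float(eigs.max())
    d = means[0].shape[0]  # assumption: d is the dimension
    kappa = lam_max / lam_min

    # Threshold from condition 2, with log read as the natural logarithm.
    required = float(np.sqrt(2.0 * d * lam_max * (np.log(kappa) + 2.0)))

    # Assumption: separation means the minimum pairwise distance between means.
    n = len(means)
    observed = min(
        float(np.linalg.norm(means[i] - means[j]))
        for i in range(n)
        for j in range(i + 1, n)
    )
    return {
        "lambda_min": lam_min,
        "lambda_max": lam_max,
        "kappa": kappa,
        "required_separation": required,
        "observed_separation": observed,
        "separation_ok": observed >= required,
        "contractive_radius": float(np.sqrt(lam_min)),
    }


means = [np.array([0.0, 0.0]), np.array([10.0, 0.0])]
covs = [np.eye(2), np.diag([2.0, 1.0])]
for name, value in theorem1_quantities(means, covs).items():
    print(f"{name}: {value}")
```

A check like this can be useful before running an experiment, since it shows whether a given pattern set falls inside the regime covered by the guarantee or only in the empirical regime.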

## Computational Notes

- **Synthetic experiments**: ~1 minute on CPU
- **Word/Sentence experiments**: 5-30 minutes depending on training
- **Image experiments**: 1-4 hours (GPU recommended)

For faster experimentation, reduce:
- `num_patterns` / `num_images`: Number of stored patterns
- `num_trials`: Number of trials for error bars
- `vae_epochs`: VAE training epochs (for image experiments)

## Google Colab

The experiments are designed to run in Google Colab. For the word embedding experiment, the full pipeline (including Word2Gauss installation and training) is documented in the comments of `word_retrieval.py`.

## Citation

```bibtex
@inproceedings{author2025bwdam,
  title={Bures-Wasserstein Dense Associative Memory},
  author={...},
  booktitle={International Conference on Machine Learning},
  year={2025}
}
```

## License

[Add your license here]
