# Llama 3.2 1B with FMA Attention in Flax

This repository contains an implementation of Llama 3.2 1B in Flax/JAX with Fast Multipole Attention (FMA) approximation for efficient inference on long sequences.

## Overview

This project integrates the FMA attention approximation from the [../fma](../fma) repository with Meta's Llama 3.2 1B model. The goal is efficient inference on long sequences (Llama 3.2 1B supports context lengths of up to 128K tokens) at lower cost than standard quadratic attention. See the TODO list below for the current integration status.

### Key Features

- Llama 3.2 1B implementation in Flax/JAX
- FMA attention approximation targeting O(n) attention complexity (integration in progress)
- Pretrained weight conversion from PyTorch/HuggingFace
- PG-19 evaluation scripts with Llama tokenization
- Optimized for single-GPU execution
- Inference and testing utilities
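The divide-and-conquer idea behind FMA can be illustrated with a deliberately simplified, single-level JAX sketch: each query attends exactly to earlier tokens in its own block, plus one averaged summary key/value per earlier block, so it sees roughly `block_size + seq_len / block_size` keys instead of `seq_len`. This is a hypothetical illustration written for this README (the function name `two_level_attention` is made up), not the code in `../fma`; the real variants are hierarchical and may build block summaries differently.

```python
import jax
import jax.numpy as jnp


def two_level_attention(q, k, v, block_size):
    """Causal attention with an exact near field and a block-mean far field.

    Hypothetical sketch. q, k, v: [seq_len, head_dim]; seq_len must be a
    multiple of block_size.
    """
    n, d = q.shape
    nb = n // block_size
    scale = 1.0 / jnp.sqrt(d)
    pos = jnp.arange(n)
    q_block = pos // block_size  # block index of every position

    # Near field: exact causal scores, restricted to the query's own block.
    fine = (q @ k.T) * scale  # [n, n]
    fine_mask = (q_block[:, None] == q_block[None, :]) & (pos[:, None] >= pos[None, :])

    # Far field: one summary key/value per block (the block mean).
    k_sum = k.reshape(nb, block_size, d).mean(axis=1)
    v_sum = v.reshape(nb, block_size, d).mean(axis=1)
    # + log(block_size) makes each summary count as block_size keys in the softmax.
    coarse = (q @ k_sum.T) * scale + jnp.log(block_size)  # [n, nb]
    coarse_mask = jnp.arange(nb)[None, :] < q_block[:, None]  # strictly earlier blocks only

    scores = jnp.concatenate(
        [jnp.where(fine_mask, fine, -jnp.inf), jnp.where(coarse_mask, coarse, -jnp.inf)],
        axis=1,
    )
    weights = jax.nn.softmax(scores, axis=-1)  # every row keeps at least its diagonal entry
    return weights @ jnp.concatenate([v, v_sum], axis=0)
```

Adding `log(block_size)` to each summary score weights it as if it stood for all the keys it averages; when the keys inside a block are identical, the approximation is exact. Dense masks keep the sketch short, so it still costs O(n²) like exact attention; an efficient version would gather only the blocks each query needs.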

## Project Structure

```
fma-flax-llama/
├── fma_llama/              # Main package
│   ├── model/              # Model implementation
│   │   ├── config.py       # Llama 3.2 1B configuration
│   │   └── llama.py        # Model architecture
│   └── attention/          # Attention implementations
│       └── fma_attention.py # FMA attention wrapper
├── scripts/                # Utility scripts
│   ├── convert_weights.py  # Convert PyTorch weights to Flax
│   ├── run_inference.py    # Run inference
│   ├── evaluate_pg19.py    # Evaluate on PG-19 dataset
│   └── prepare_pg19.py     # Prepare PG-19 data
├── configs/                # Configuration files
│   ├── llama_3.2_1b_fma.yaml
│   └── llama_3.2_1b_standard.yaml
├── data/                   # Dataset storage
│   ├── raw/
│   └── processed/
├── checkpoints/            # Model checkpoints
└── tests/                  # Test files
```

## Setup

### Docker Setup (Recommended on Shared GPU Nodes)

This repository is configured to run in Docker on shared GPU nodes. The setup automatically mounts the FMA repository for attention implementations.

1. Ensure Docker is installed with the Compose plugin (`docker compose`) and GPU support (for NVIDIA GPUs, the NVIDIA Container Toolkit)

2. Update the `.env` file with your environment variables:
```bash
DATASETS_ROOT_DIR='~/storage/datasets'
HF_HOME='~/storage/datasets/huggingface'
WANDB_DIR='~/storage/wandb'
JAX_DIR='~/storage/jax'
```

3. Build and start the Docker container:
```bash
docker compose build
docker compose up -d
```

4. Enter the container:
```bash
docker compose exec main bash
```

The package will be automatically installed on container startup via the `bashrc_docker` script.

### Local Installation (Alternative)

If you're not using Docker:

1. Clone the repository:
```bash
git clone <your-repo-url>
cd fma-flax-llama
```

2. Install dependencies:
```bash
pip install -e .
```

Or with development dependencies:
```bash
pip install -e ".[dev]"
```

### Download and Convert Pretrained Weights

Convert Llama 3.2 1B weights from HuggingFace format to Flax:

```bash
python scripts/convert_weights.py \
    --model_name meta-llama/Llama-3.2-1B \
    --output_dir checkpoints/llama-3.2-1b-flax \
    --use_auth_token  # If using gated model
```

Note: Llama 3.2 models are gated on Hugging Face. Request access on the model page and authenticate (e.g. with `huggingface-cli login`) before running the conversion.
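Converting the weights mostly means renaming parameters and transposing linear layers: `torch.nn.Linear` stores its weight as `[out_features, in_features]`, while `flax.linen.Dense` expects a `kernel` of shape `[in_features, out_features]`. The NumPy sketch below shows that mapping. The HF-side names follow the Hugging Face Llama implementation; the Flax-side tree (`layers_{i}`, `self_attn`, `q_proj`, ...) and the function name `convert_state_dict` are hypothetical and must match whatever `fma_llama/model/llama.py` actually defines. It also assumes the Flax model applies rotary embeddings with the same rotate-half convention as the HF code, so no q/k permutation is needed.

```python
import re

import numpy as np

# HF Llama parameter names; the target Flax tree layout is hypothetical.
_PROJ_RE = re.compile(r"model\.layers\.(\d+)\.(self_attn|mlp)\.(\w+_proj)\.weight")
_NORM_RE = re.compile(r"model\.layers\.(\d+)\.(input_layernorm|post_attention_layernorm)\.weight")


def convert_state_dict(state_dict):
    """Convert {hf_name: array} into a nested dict of Flax-style params."""
    params = {}

    def put(path, value):
        node = params
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value

    for name, w in state_dict.items():
        w = np.asarray(w)
        if name.endswith("rotary_emb.inv_freq"):
            continue  # non-trainable buffer found in some older checkpoints
        if m := _PROJ_RE.fullmatch(name):
            layer, block, proj = m.groups()
            # torch Linear weight is [out, in]; Flax Dense kernel is [in, out].
            put((f"layers_{layer}", block, proj, "kernel"), w.T)
        elif m := _NORM_RE.fullmatch(name):
            layer, norm = m.groups()
            put((f"layers_{layer}", norm, "scale"), w)  # RMSNorm gain, no transpose
        elif name == "model.embed_tokens.weight":
            put(("embed_tokens", "embedding"), w)  # [vocab, hidden], no transpose
        elif name == "model.norm.weight":
            put(("norm", "scale"), w)
        elif name == "lm_head.weight":
            put(("lm_head", "kernel"), w.T)
        else:
            raise KeyError(f"unmapped parameter: {name}")

    # Llama 3.2 1B ties input and output embeddings; reuse them if lm_head is absent.
    if "lm_head" not in params and "embed_tokens" in params:
        params["lm_head"] = {"kernel": params["embed_tokens"]["embedding"].T}
    return params
```

One way to obtain the input dict is to load the model with `transformers.AutoModelForCausalLM.from_pretrained(...)` and build `{k: v.float().numpy() for k, v in model.state_dict().items()}`; the `.float()` is needed because NumPy has no bfloat16 type. Raising on unmapped names makes a naming mismatch fail at conversion time rather than at inference time.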

## Usage

**Note:** All commands below should be run inside the Docker container (accessed via `docker compose exec main bash`) if using the Docker setup.

### Inference

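Run inference with FMA attention enabled (`--use_fma`). Until the FMA integration is complete (see TODO), this path uses standard attention as a placeholder: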

```bash
python scripts/run_inference.py \
    --checkpoint_dir checkpoints/llama-3.2-1b-flax \
    --prompt "Once upon a time" \
    --max_new_tokens 50 \
    --temperature 1.0 \
    --use_fma
```
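`--temperature` controls how each next token is sampled from the model's logits. Below is a minimal sketch of temperature sampling with `jax.random.categorical`, assuming a hypothetical `next_logits_fn` that maps the token ids so far to last-position logits of shape `[vocab]`; it is not the code in `scripts/run_inference.py`, and treating temperature 0 as greedy decoding is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def sample_next_token(key, logits, temperature=1.0):
    """Pick the next token id from logits of shape [vocab]."""
    if temperature <= 0.0:
        # Assumption: temperature 0 means greedy decoding.
        return jnp.argmax(logits, axis=-1)
    # T < 1 sharpens the distribution, T > 1 flattens it.
    return jax.random.categorical(key, logits / temperature, axis=-1)


def generate(key, next_logits_fn, prompt_ids, max_new_tokens, temperature=1.0):
    """Autoregressive loop; next_logits_fn is a hypothetical model wrapper."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        key, subkey = jax.random.split(key)  # fresh randomness per step
        # No KV cache: the whole sequence is re-encoded at every step.
        logits = next_logits_fn(jnp.asarray(ids))
        ids.append(int(sample_next_token(subkey, logits, temperature)))
    return ids
```

Re-running the full sequence each step is quadratic in the generated length; the KV cache listed in the TODO would let each step process only the newest token.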

### Prepare PG-19 Dataset

Prepare the PG-19 dataset with Llama tokenization:

```bash
python scripts/prepare_pg19.py \
    --output_dir data/processed \
    --tokenizer meta-llama/Llama-3.2-1B \
    --max_length 2048 \
    --split test
```
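One plausible way to turn tokenized PG-19 books into evaluation sequences of `--max_length` tokens is non-overlapping chunking within each book. The helper names below are hypothetical, tokenization (e.g. with the tokenizer for `meta-llama/Llama-3.2-1B`) is assumed to have already happened, and `scripts/prepare_pg19.py` may treat remainders or document boundaries differently:

```python
import numpy as np


def chunk_tokens(token_ids, max_length):
    """Split one book's token ids into non-overlapping [num_chunks, max_length] rows.

    The trailing partial chunk is dropped so every row has exactly max_length tokens.
    """
    ids = np.asarray(token_ids, dtype=np.int32)
    num_chunks = len(ids) // max_length
    return ids[: num_chunks * max_length].reshape(num_chunks, max_length)


def build_eval_set(books, max_length):
    """Chunk each book separately so no sequence spans two books."""
    chunks = [chunk_tokens(book, max_length) for book in books]
    if not chunks:
        return np.empty((0, max_length), dtype=np.int32)
    return np.concatenate(chunks, axis=0)
```

Chunking per book keeps every sequence within a single document, and since PG-19 books are long, each window is filled with contiguous text from one book.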

### Evaluate on PG-19

Evaluate model quality and speed on PG-19:

```bash
python scripts/evaluate_pg19.py \
    --checkpoint_dir checkpoints/llama-3.2-1b-flax \
    --max_length 2048 \
    --num_samples 100 \
    --compare_standard  # Also evaluate standard attention
```

This will output:
- Perplexity (quality metric)
- Tokens/second (speed metric)
- Comparison between FMA and standard attention (if enabled)
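Perplexity is the exponential of the mean next-token negative log-likelihood, and throughput is tokens processed per wall-clock second. The sketch below computes both, assuming a hypothetical `forward_fn` that maps `[batch, seq]` token ids to `[batch, seq, vocab]` logits; `scripts/evaluate_pg19.py` may compute them differently (for example, with padding masks).

```python
import time

import jax
import jax.numpy as jnp


def perplexity(logits, tokens):
    """logits: [batch, seq, vocab]; tokens: [batch, seq]. Position t predicts t + 1."""
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)
    targets = tokens[:, 1:]
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return jnp.exp(nll.mean())


def tokens_per_second(forward_fn, tokens, warmup=1, iters=3):
    """Throughput of a (hypothetical) forward_fn on a [batch, seq] token array."""
    for _ in range(warmup):
        jax.block_until_ready(forward_fn(tokens))  # first call includes JIT compilation
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously; wait for the result before stopping the clock.
        jax.block_until_ready(forward_fn(tokens))
    elapsed = time.perf_counter() - start
    return tokens.size * iters / elapsed
```

Without `jax.block_until_ready`, the timer would mostly measure dispatch rather than computation. When aggregating over many batches, average the NLL over all tokens first and exponentiate once; averaging per-batch perplexities gives a different number.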

## Configuration

Configuration files are located in `configs/`. Key parameters:

- `model.*`: Model architecture parameters
- `fma.use_fma_attention`: Enable/disable FMA attention
- `fma.block_size`: Block size for FMA approximation
- `fma.num_clusters`: Number of clusters for approximation
- `fma.attention_type`: Type of FMA attention (`pallas_retrieval`, `single_level_attention`, etc.)
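The `fma.*` keys above can be mirrored in a small typed config so that typos and invalid values fail early. This is a hypothetical sketch: the class name `FMAConfig` and its default values are illustrative, not taken from the YAML files. The input dict could come from, e.g., `yaml.safe_load` (PyYAML) on `configs/llama_3.2_1b_fma.yaml`.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class FMAConfig:
    """Typed view of the `fma.*` keys (defaults are hypothetical)."""

    use_fma_attention: bool = True
    block_size: int = 128
    num_clusters: int = 16
    attention_type: str = "single_level_attention"

    def __post_init__(self):
        if self.block_size <= 0 or self.num_clusters <= 0:
            raise ValueError("block_size and num_clusters must be positive")

    @classmethod
    def from_config(cls, config):
        """Build from a parsed config dict that has an `fma` section."""
        section = dict(config.get("fma", {}))
        unknown = set(section) - {f.name for f in fields(cls)}
        if unknown:
            # Fail loudly on typos such as `blocksize` instead of silently using defaults.
            raise ValueError(f"unknown fma keys: {sorted(unknown)}")
        return cls(**section)
```

Rejecting unknown keys is a deliberate choice: a misspelled key would otherwise fall back to a default and quietly change the experiment.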

## FMA Attention Integration

The FMA attention approximation is imported from the neighboring `../fma` repository. To use a specific FMA attention variant, update the imports in `fma_llama/attention/fma_attention.py`:

```python
from fma.pallas_retrieval import retrieval_attention
# or
from fma.single_level_attention import single_level_attention
# or
from fma.pallas_monopole import monopole_attention
```

Then integrate the chosen attention function in the `FMAAttention.__call__` method.
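`FMAAttention.__call__` needs a standard-attention path (the current placeholder, per the TODO) and a hook for the chosen FMA function. Below is a minimal functional sketch without the Flax module boilerplate. The `fma_fn(q, k, v)` signature and the name `attention_forward` are assumptions: the functions in `../fma` (`retrieval_attention`, `single_level_attention`, `monopole_attention`) may take different arguments or tensor layouts. Llama 3.2 1B uses grouped-query attention (32 query heads sharing 8 key/value heads), so keys and values are repeated per query head before either path.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """Exact causal softmax attention; q, k, v: [batch, heads, seq, head_dim]."""
    seq, head_dim = q.shape[-2], q.shape[-1]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(head_dim)
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


def attention_forward(q, k, v, fma_fn=None):
    """q: [b, n_heads, s, d]; k, v: [b, n_kv_heads, s, d] (grouped-query attention)."""
    n_rep = q.shape[1] // k.shape[1]
    # Repeat each KV head for the query heads that share it: kv0, kv0, ..., kv1, kv1, ...
    k = jnp.repeat(k, n_rep, axis=1)
    v = jnp.repeat(v, n_rep, axis=1)
    # Hypothetical hook: any callable with the same (q, k, v) -> output contract.
    attn = fma_fn if fma_fn is not None else causal_attention
    return attn(q, k, v)
```

Keeping exact attention as the default path makes it easy to compare an FMA variant against it on short sequences, where both should produce similar outputs.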

## Development

### Running Tests

```bash
pytest tests/
```

### Code Formatting

```bash
black fma_llama/ scripts/ tests/
ruff check fma_llama/ scripts/ tests/
```

## TODO

- [ ] Implement FMA attention integration (currently uses standard attention as placeholder)
- [ ] Add support for multiple FMA attention variants
- [ ] Add KV cache for faster inference
- [ ] Add benchmarking utilities
- [ ] Add visualization tools for attention patterns
- [ ] Support for longer context lengths (32K, 64K, 128K)

## References

- [Llama 3.2](https://huggingface.co/meta-llama/Llama-3.2-1B)
- [FMA Repository](../fma)
- [PG-19 Dataset](https://github.com/deepmind/pg19)

## License

[Add your license here]
