# SAE_Dymistified
Training scripts for Grouped Basis Autoencoders (GBA)

## Overview

This repository contains the implementation and training scripts for Grouped Basis Autoencoders (GBA), a novel approach to training sparse autoencoders (SAEs) with improved interpretability and performance.

## Repository Structure

```
.
├── Group_SAE/                      # Core GBA implementation
│   ├── SAETran_model_v2.py       # Main GBA model with transformer integration
│   └── SAE_model_v2.py           # Base SAE model implementations
│
├── Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/   # Data preprocessing pipeline
│   ├── core/
│   │   ├── data_preprocess.py                     # Single-GPU preprocessing
│   │   └── data_preprocess_parallel_sharded.py    # Multi-GPU sharded preprocessing
│   ├── slurm_scripts/                             # SLURM job scripts
│   └── README.md                                  # Detailed preprocessing docs
│
├── Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/            # SAE training pipeline
│   └── core/
│       ├── train_entry_sharded.py                 # Main training entry point
│       ├── sharded_data_module.py                 # Sharded data loading
│       └── memory_block_data_module.py            # Memory-efficient data loading
│
├── Simtransformer/                 # Synthetic data experiments
│   ├── Tk_linear_reg.py          # Linear regression experiments
│   └── simtransformer/            # Transformer simulation utilities
│
├── train_sae_GBA_jumprelu.sh      # GBA training script with JumpReLU
└── train_sae_topk_jumprelu.sh     # TopK SAE training script with JumpReLU
```

## Quick Start

### 1. Data Preprocessing

First, extract MLP-output activations from the transformer (here Qwen2.5-1.5B) and write them to disk:

```bash
# Single GPU preprocessing
python Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/core/data_preprocess.py \
    --model_path "Qwen/Qwen2.5-1.5B" \
    --dataset "timaeus/pile-github" \
    --max_length 1024 \
    --layers 2 13 26 \
    --output_files output_L2.h5 output_L13.h5 output_L26.h5

# Multi-GPU parallel preprocessing with sharding
python Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/core/data_preprocess_parallel_sharded.py \
    --model_path "Qwen/Qwen2.5-1.5B" \
    --dataset "timaeus/pile-github" \
    --max_length 1024 \
    --batch_size 128 \
    --layers 2 13 26 \
    --output_files output_L2 output_L13 output_L26 \
    --num_gpus 4 \
    --sequences_per_shard 50000
```

For SLURM clusters:
```bash
# Submit preprocessing job
sbatch Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/slurm_scripts/process_500k_parallel_sharded.sh
```
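Before launching a large run, it can help to estimate how much disk the sharded output needs. The sketch below gives a rough upper bound under stated assumptions that are not documented behavior of `data_preprocess_parallel_sharded.py`: one activation vector is stored per token, every sequence fills `max_length`, and values are stored at 2 bytes each (fp16/bf16). `d_model = 1536` is the hidden size of Qwen2.5-1.5B, and the 500k sequence total is taken from the name of the SLURM script. The function name is hypothetical.

```python
import math


def estimate_sharded_storage(
    num_sequences: int,
    sequences_per_shard: int,
    max_length: int,
    d_model: int,
    num_layers: int,
    bytes_per_value: int = 2,  # assumed fp16/bf16 storage
) -> dict:
    """Upper-bound disk estimate, assuming one vector per token and full-length sequences."""
    num_shards = math.ceil(num_sequences / sequences_per_shard)
    tokens_per_shard = sequences_per_shard * max_length
    bytes_per_shard = tokens_per_shard * d_model * bytes_per_value
    total_bytes = num_sequences * max_length * d_model * bytes_per_value * num_layers
    return {
        "num_shards_per_layer": num_shards,
        "tokens_per_shard": tokens_per_shard,
        "gb_per_shard": bytes_per_shard / 1e9,
        "total_gb": total_bytes / 1e9,
    }


# Values from the sharded command above, with an assumed 500k-sequence run
est = estimate_sharded_storage(
    num_sequences=500_000,
    sequences_per_shard=50_000,
    max_length=1024,
    d_model=1536,
    num_layers=3,  # layers 2, 13, 26
)
for key, value in est.items():
    print(f"{key}: {value:,.1f}" if isinstance(value, float) else f"{key}: {value:,}")
```

If the pipeline drops padding tokens or stores fp32, scale the estimate accordingly.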

### 2. Training Sparse Autoencoders

#### Grouped Basis Autoencoder (GBA) with JumpReLU

```bash
# Edit paths in the script first
vim train_sae_GBA_jumprelu.sh

# Run training
bash train_sae_GBA_jumprelu.sh
```

Key parameters in `train_sae_GBA_jumprelu.sh`:
- `--model_name`: GBA_SAE
- `--d_in`: Input dimension; must match the hooked activation width (1536 for the MLP output of Qwen2.5-1.5B)
- `--d_sae`: SAE hidden dimension (e.g., 131072)
- `--group_size`: Size of basis groups (e.g., 512)
- `--sparsity_coeff`: Sparsity penalty coefficient
- `--lr`: Learning rate
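The size parameters above fix the shape of the model. The sketch below derives the number of groups and an approximate parameter count for `d_in = 1536` (the Qwen2.5-1.5B hidden size), `d_sae = 131072`, and `group_size = 512`. It assumes `group_size` is the number of latents per group, so groups partition the `d_sae` latents evenly, and it counts parameters as for a standard SAE (encoder matrix, decoder matrix, two bias vectors); the actual GBA parameterization may differ. The helper name is hypothetical.

```python
def sae_shape_summary(d_in: int, d_sae: int, group_size: int) -> dict:
    """Shape arithmetic for an SAE whose latents are split into equal-size groups.

    Assumes group_size = latents per group and a standard SAE parameter layout:
    W_enc (d_in x d_sae), W_dec (d_sae x d_in), b_enc (d_sae), b_dec (d_in).
    """
    if d_sae % group_size != 0:
        raise ValueError("d_sae must be divisible by group_size")
    n_params = 2 * d_in * d_sae + d_sae + d_in
    return {
        "n_groups": d_sae // group_size,
        "expansion_factor": d_sae / d_in,
        "n_params": n_params,
        "fp32_weights_gb": n_params * 4 / 1e9,  # weights only, no optimizer state
    }


print(sae_shape_summary(d_in=1536, d_sae=131072, group_size=512))
```

With Adam, budget roughly two more fp32 copies of the weights for optimizer state.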

#### TopK SAE with JumpReLU

```bash
# Edit paths in the script first
vim train_sae_topk_jumprelu.sh

# Run training
bash train_sae_topk_jumprelu.sh
```

Key parameters in `train_sae_topk_jumprelu.sh`:
- `--model_name`: TopK_SAE
- `--k`: Number of latents kept active per token (the K in TopK)
- `--d_in`: Input dimension
- `--d_sae`: SAE hidden dimension
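A TopK activation keeps only the `k` largest pre-activations per token and zeroes the rest. The PyTorch sketch below is a generic illustration of that operation; `topk_activation` is a hypothetical name, not a function from `SAE_model_v2.py`. Applying ReLU to the kept values, as done here, is a common choice that keeps the codes non-negative.

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest entries along the last dim (after ReLU), zero the rest."""
    values, indices = torch.topk(pre_acts, k, dim=-1)
    out = torch.zeros_like(pre_acts)
    # Scatter the kept values back into their original latent positions
    out.scatter_(-1, indices, torch.relu(values))
    return out


pre = torch.tensor([[0.5, -1.0, 2.0, 0.1, 1.5]])
print(topk_activation(pre, k=2))  # nonzero only at positions 2 and 4
```

Because the operation works on the last dimension, it applies unchanged to `(batch, d_sae)` or `(batch, seq_len, d_sae)` inputs.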

### 3. Using Trained Models

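```python
import torch

from Group_SAE.SAETran_model_v2 import GBA_SAE

# Load a trained model (load_from_checkpoint assumes GBA_SAE is a LightningModule)
model = GBA_SAE.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()

# Example shapes; d_in must match the dimension the SAE was trained on
batch_size, seq_len, d_in = 8, 128, 1536
activations = torch.randn(batch_size, seq_len, d_in, device=model.device)

# Encode to sparse latents, then reconstruct the input
with torch.no_grad():
    encoded = model.encode(activations)
    reconstructed = model.decode(encoded)
```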
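To sanity-check a trained model, one can compare `activations` with `reconstructed` and inspect `encoded`. The sketch below computes mean squared error, fraction of variance unexplained (FVU), and the average number of active latents per token (L0). The helper `reconstruction_metrics` is hypothetical and works on plain tensors, so it does not depend on the GBA_SAE API; it only assumes inactive latents in `encoded` are exactly zero.

```python
import torch


def reconstruction_metrics(x: torch.Tensor, x_hat: torch.Tensor, codes: torch.Tensor) -> dict:
    """MSE, fraction of variance unexplained, and mean L0 over all tokens.

    x, x_hat: (..., d_in); codes: (..., d_sae). Leading dims are flattened.
    """
    x = x.reshape(-1, x.shape[-1])
    x_hat = x_hat.reshape(-1, x_hat.shape[-1])
    codes = codes.reshape(-1, codes.shape[-1])
    mse = (x - x_hat).pow(2).mean()
    # FVU: residual energy relative to the energy around the per-dimension mean
    fvu = (x - x_hat).pow(2).sum() / (x - x.mean(dim=0)).pow(2).sum()
    l0 = (codes != 0).float().sum(dim=-1).mean()
    return {"mse": mse.item(), "fvu": fvu.item(), "l0": l0.item()}
```

After the usage snippet, `reconstruction_metrics(activations, reconstructed, encoded)` returns the three numbers. An FVU of 0 means perfect reconstruction, while 1 means no better than predicting the mean activation.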

## Training Scripts Details

### GBA with JumpReLU (`train_sae_GBA_jumprelu.sh`)

Trains a Grouped Basis Autoencoder with JumpReLU activation. Features:
- Grouped basis for improved interpretability
- JumpReLU for sparse activation
- Configurable group sizes
- Support for multiple layers
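JumpReLU passes a pre-activation through unchanged only if it exceeds a per-latent threshold θ, that is f(z) = z · H(z − θ), where H is the Heaviside step. The forward pass can be sketched in PyTorch as follows. The `JumpReLU` module name and the log-parameterized threshold are assumptions for illustration. Because the step function is not differentiable, θ receives no gradient in this sketch; training the threshold requires a surrogate gradient (for example a straight-through estimator), which is not implemented here.

```python
import torch
from torch import nn


class JumpReLU(nn.Module):
    """Forward pass of JumpReLU: z * 1[z > theta], with one threshold per latent.

    theta is stored as log(theta) so it stays positive. In this sketch theta
    receives no gradient; real training needs a surrogate gradient.
    """

    def __init__(self, d_sae: int, init_threshold: float = 0.001):
        super().__init__()
        self.log_threshold = nn.Parameter(torch.full((d_sae,), float(init_threshold)).log())

    def forward(self, pre_acts: torch.Tensor) -> torch.Tensor:
        theta = self.log_threshold.exp()
        # Zero every entry at or below its latent's threshold, keep the rest as-is
        return pre_acts * (pre_acts > theta).to(pre_acts.dtype)
```

Unlike ReLU, values just above zero but below θ are also cut, which is what produces exact sparsity without shrinking the surviving activations.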

### TopK SAE with JumpReLU (`train_sae_topk_jumprelu.sh`)

Trains a TopK Sparse Autoencoder with JumpReLU. Features:
- TopK sparsity constraint
- JumpReLU activation
- Configurable K value
- Efficient training on large datasets

## Data Preprocessing Pipeline

See [Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/README.md](Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/README.md) for detailed documentation on:
- Extracting transformer activations
- Multi-GPU parallel processing
- Sharding large datasets
- Output format options (HDF5, Arrow, Parquet)

## Synthetic Experiments

The `Simtransformer/` directory contains synthetic data experiments:
- Linear regression with transformers
- Controlled experiments for interpretability
- Testing ground for new architectures
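A common synthetic setup for linear regression with transformers is in-context regression: each sequence holds pairs (x_i, y_i = w · x_i) for a freshly sampled weight vector w, and the model must predict y for a query x. The sketch below generates such batches. It is an assumed illustration of this task family, not the data generator in `Tk_linear_reg.py`, and `sample_linear_regression_batch` is a hypothetical name.

```python
import torch


def sample_linear_regression_batch(
    batch_size: int, n_points: int, dim: int, noise_std: float = 0.0, generator=None
):
    """Sample in-context linear regression tasks.

    Returns xs: (batch, n_points, dim), ys: (batch, n_points), ws: (batch, dim),
    with ys = xs @ w + noise, using a separate w for each sequence in the batch.
    """
    xs = torch.randn(batch_size, n_points, dim, generator=generator)
    ws = torch.randn(batch_size, dim, generator=generator)
    ys = torch.einsum("bnd,bd->bn", xs, ws)
    if noise_std > 0:
        ys = ys + noise_std * torch.randn(ys.shape, generator=generator)
    return xs, ys, ws


g = torch.Generator().manual_seed(0)
xs, ys, ws = sample_linear_regression_batch(batch_size=2, n_points=8, dim=4, generator=g)
print(xs.shape, ys.shape, ws.shape)  # torch.Size([2, 8, 4]) torch.Size([2, 8]) torch.Size([2, 4])
```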

## Requirements

```bash
# Install required packages
pip install torch pytorch-lightning transformers datasets h5py
pip install transformer-lens  # For activation extraction
```
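To confirm the environment before launching a long job, a short check like the one below reports which of the packages above cannot be found. Note that import names differ from pip names in places (for example `pytorch-lightning` imports as `pytorch_lightning` and `transformer-lens` as `transformer_lens`). The script and function names are illustrative.

```python
import importlib.util

# Import names for the packages installed above
REQUIRED_MODULES = [
    "torch",
    "pytorch_lightning",
    "transformers",
    "datasets",
    "h5py",
    "transformer_lens",
]


def missing_modules(names):
    """Return the subset of module names that cannot be found on this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("All requirements found." if not missing else f"Missing: {', '.join(missing)}")
```

`find_spec` locates a module without importing it, so the check stays fast even for heavy packages like `torch`.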

## Configuration

Before running the scripts, update the following settings:
1. Data paths in the training scripts
2. Model checkpoint paths
3. Output directories
4. SLURM configuration (if running on a cluster)
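A small pre-flight check can catch stale paths before a job is queued. The sketch below is hypothetical: neither the function nor its arguments are part of the repository, and you would pass in the input and output paths you set in the training script.

```python
from pathlib import Path


def preflight(inputs, output_dirs):
    """Return a list of problems: missing input paths and output dirs that cannot be created."""
    problems = []
    for p in map(Path, inputs):
        if not p.exists():
            problems.append(f"missing input: {p}")
    for d in map(Path, output_dirs):
        try:
            # Creating the directory up front also verifies write permission
            d.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            problems.append(f"cannot create output dir {d}: {exc}")
    return problems
```

An empty list means every input exists and every output directory is ready.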

## Monitoring Training

```bash
# View tensorboard logs
tensorboard --logdir lightning_logs/

# Monitor SLURM jobs
squeue -u $USER
tail -f slurm_output/*.out
```
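By default, PyTorch Lightning's `TensorBoardLogger` writes each run to a new `lightning_logs/version_N` directory. To point TensorBoard at the most recent run only, a helper like the following can pick the highest version number. It is a sketch that assumes this default layout, and the function name is illustrative.

```python
import re
from pathlib import Path
from typing import Optional


def latest_version_dir(log_root: str = "lightning_logs") -> Optional[Path]:
    """Return the version_N subdirectory with the largest N, or None if there is none."""
    pattern = re.compile(r"^version_(\d+)$")
    best = None
    for d in Path(log_root).glob("version_*"):
        m = pattern.match(d.name)
        # Compare numerically so version_10 sorts after version_9
        if d.is_dir() and m and (best is None or int(m.group(1)) > best[0]):
            best = (int(m.group(1)), d)
    return best[1] if best else None


if __name__ == "__main__":
    print(latest_version_dir())
```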
