# Companion code for "The emergence of sparse attention: impact of data distribution and benefits of repetition"

This repository contains the code used for the experiments in the paper "The emergence of sparse attention: impact of data distribution and benefits of repetition".

## Structure of the code

The codebase is organized as follows:

*   `ar_ic_learning.py`: Entry point for in-context associative recall learning experiments.
*   `sl_lr_learning.py`: Entry point for single location linear regression learning experiments.
*   `conf/`: Contains configuration files for models, datasets, training, and runs, managed by Hydra.
    *   `config.yaml`: Main configuration file specifying defaults.
    *   `dataset/`: Dataset-specific configurations.
    *   `model/`: Model-specific configurations.
    *   `run/`: Run-specific configurations (e.g., W&B details).
    *   `training/`: Training-specific configurations.
*   `data/`: Contains data loading and generation scripts.
    *   `ass_recall/`: For associative recall tasks.
    *   `single_location_linear_regression/`: For single location linear regression tasks.
*   `exp_utils/`: Utility functions for setting up and running experiments.
*   `experiments/`: Contains experiment-specific logic, including training and evaluation loops.
    *   `ar_ic/`: For associative recall in-context learning.
    *   `single_location_linear_regression/`: For single location linear regression.
*   `figures/`: Contains Jupyter notebooks and Python scripts for generating figures.
*   `models/`: Implementation of transformer models and components.
*   `outputs/`: Default directory for experiment outputs (can be overridden by Hydra).
*   `sweeps/`: Contains W&B sweep configuration files and scripts to launch sweeps.

## How to run an experiment

1.  **Install requirements:**
    ```bash
    pip install -r requirements.txt
    ```

2.  **Set W&B entity:**
    Before running sweeps, ensure your W&B entity is set. You can either:
    *   Set the `WANDB_ENTITY` environment variable:
        ```bash
        export WANDB_ENTITY=<your_wandb_entity>
        ```
    *   Or, modify `conf/run/defaults.yaml` to replace `TODO_FILL_WITH_YOUR_WANDB_ENTITY` with your actual W&B entity.

3.  **Running single experiments:**
    You can run individual experiments using the main entry points:
    *   For associative recall:
        ```bash
        python ar_ic_learning.py <experiment_configs>
        ```
    *   For single location linear regression:
        ```bash
        python sl_lr_learning.py <experiment_configs>
        ```
    `<experiment_configs>` are Hydra-style overrides. For example, to change the learning rate, you might use `training.lr=0.0001`.

4.  **Launching sweeps with W&B:**
    To launch a W&B sweep, use the `launch_local_sweep.sh` script:
    ```bash
    bash sweeps/launch_local_sweep.sh [GPU_DEVICES] <SWEEP_CONFIG_FILE>
    ```
    *   `[GPU_DEVICES]` (optional): Comma-separated list of GPU devices to use (e.g., "0,1"). Defaults to "0 1 2 3".
    *   `<SWEEP_CONFIG_FILE>`: Path to the sweep configuration YAML file (e.g., `sweeps/ar_figure_4.yaml`).

    The script will start a W&B sweep and then launch agents on the specified GPU devices.
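
The one-agent-per-GPU pattern described above can be sketched in Python. This is an illustration rather than the contents of `launch_local_sweep.sh`: the helper names (`parse_gpu_devices`, `agent_commands`, `launch_agents`) are hypothetical, and pinning each agent with `CUDA_VISIBLE_DEVICES` is an assumption about how the devices are assigned. The parser accepts both comma-separated and space-separated device lists.

```python
import os
import re
import subprocess


def parse_gpu_devices(spec: str) -> list[str]:
    """Split a GPU list such as "0,1" or "0 1 2 3" into device IDs."""
    return [tok for tok in re.split(r"[,\s]+", spec.strip()) if tok]


def agent_commands(sweep_id: str, gpu_spec: str = "0 1 2 3"):
    """Return one (env, argv) pair per GPU so each agent sees a single device."""
    jobs = []
    for gpu in parse_gpu_devices(gpu_spec):
        env = {"CUDA_VISIBLE_DEVICES": gpu}  # assumption: one visible GPU per agent
        jobs.append((env, ["wandb", "agent", sweep_id]))
    return jobs


def launch_agents(sweep_id: str, gpu_spec: str = "0 1 2 3"):
    """Start the agents in the background and return their process handles."""
    procs = []
    for env, argv in agent_commands(sweep_id, gpu_spec):
        procs.append(subprocess.Popen(argv, env={**os.environ, **env}))
    return procs
```

Here `sweep_id` would be the identifier printed by `wandb sweep <SWEEP_CONFIG_FILE>`; calling `launch_agents(sweep_id, "0,1")` would then start two agents.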

### Main hyperparameters

Hyperparameters are managed via Hydra configuration files located in the `conf` directory. Key hyperparameters include:

*   **Model (`conf/model/transformer.yaml`):**
    *   `n_layers`: Number of transformer layers.
    *   `n_heads`: Number of attention heads.
    *   `embedding_dim`: Dimension of embeddings.
    *   `pos_enc`: Type of positional encoding.
*   **Dataset (e.g., `conf/dataset/single_location_linear_regression.yaml`):**
    *   `dimension`: Input dimension.
    *   `sequence_length`: Length of input sequences.
    *   `p_repeat`: Probability of repeating a special token (for relevant tasks).
    *   `batch_size`: Batch size for training.
*   **Training (`conf/training/defaults.yaml`):**
    *   `iters`: Total training iterations.
    *   `lr`: Learning rate.
    *   `wd`: Weight decay.
    *   `opt`: Optimizer type (e.g., 'adam', 'sgd').
    *   `scheduler`: Learning rate scheduler.
    *   `eval_interval`: Frequency of evaluation.
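
Any of the hyperparameters above can be overridden from the command line. As a minimal sketch, the snippet below flattens a nested dictionary into Hydra-style `key=value` overrides and builds the command for an entry point. It assumes each config group (`model`, `training`) is exposed under its own name, as the `training.lr=0.0001` example suggests; `to_hydra_overrides` and `build_command` are hypothetical helpers, and the values shown are placeholders rather than the paper's settings.

```python
import sys


def to_hydra_overrides(params: dict, prefix: str = "") -> list[str]:
    """Flatten a nested dict into dotted Hydra overrides such as training.lr=0.0001."""
    out = []
    for key, value in params.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            out.extend(to_hydra_overrides(value, path))
        else:
            out.append(f"{path}={value}")
    return out


def build_command(entry_point: str, params: dict) -> list[str]:
    """Build the argv list for an entry point plus its overrides."""
    return [sys.executable, entry_point, *to_hydra_overrides(params)]


# Placeholder values for illustration only.
overrides = {
    "training": {"lr": 0.0001, "wd": 0.0},
    "model": {"n_layers": 2, "n_heads": 1},
}
cmd = build_command("ar_ic_learning.py", overrides)
print(" ".join(cmd[1:]))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root to start the run.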

## How to replicate figures

The following outlines how to replicate the figures presented in the paper:

*   **Figure 1 (Theory - Analysis vanilla dynamics):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 2 (Theory - Analysis repetition dynamics):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 3 (Linear regression - In-context and cross-sample generalization):**
    *   Run sweeps: `sweeps/lr_figure_3_base.yaml`, `sweeps/lr_figure_3_cross_sample.yaml`, `sweeps/lr_figure_3_in_context.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 4 (Associative recall - vanilla dynamics):**
    *   Run sweep: `sweeps/ar_figure_4.yaml`
    *   Generate plot: `figures/associative_recall.ipynb`

*   **Figures 5 & 16 (Associative recall - Repetition):**
    *   Run sweeps: `sweeps/ar_figure_5_cross_sample.yaml`, `sweeps/ar_figure_5_in_context.yaml`
    *   Generate plot: `figures/associative_recall.ipynb`

*   **Figure 6 (Appendix theory: Loss visualization):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 7 (Appendix theory: Initial dynamics vs true dynamics):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 8 (Appendix theory: late phase analysis):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 9 (Appendix linear regression - comparison theory and practice):**
    *   Run sweep: `sweeps/lr_figure_9.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 10 (Appendix theory - effect of repetition):**
    *   Generate plot: `figures/theory.ipynb`

*   **Figure 11 (Appendix linear regression - dynamics learned transformer):**
    *   Run sweep: `sweeps/lr_figure_11.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 12 (Appendix linear regression - Learning dynamics with cross sample repetition):**
    *   Runs: reuse the runs from `sweeps/lr_figure_3_cross_sample.yaml` (see Figure 3)
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 13 (Appendix linear regression - Effect of model depth and width (layers/heads)):**
    *   Run sweep: `sweeps/lr_figure_13.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 14 (Appendix linear regression - Role of task details):**
    *   Run sweep: `sweeps/lr_figure_14.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 15 (Appendix linear regression - Role of optimizer):**
    *   Run sweeps: `sweeps/lr_figure_15_sgd.yaml`, `sweeps/lr_figure_13.yaml`
    *   Generate plot: `figures/validation_theory.ipynb`

*   **Figure 17 (Appendix associative recall - Visualization attention):**
    *   Run sweep: `sweeps/ar_figure_17.yaml`
    *   Generate plot: `figures/associative_recall.ipynb`

**Note on generating figures:**

The Jupyter notebooks in the `figures/` directory (`associative_recall.ipynb`, `theory.ipynb`, `validation_theory.ipynb`) process the results from W&B and generate the plots. You will likely need to:

1.  Ensure the W&B run paths or IDs corresponding to your sweeps are correctly referenced in the notebooks.
2.  Execute the notebooks to regenerate the figures.

`figures/wandb_utils.py` and `figures/matplotlib_utils.py` contain helper functions for fetching data from W&B and plotting.
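
When checking which runs a notebook should reference, it can help to list a sweep's runs directly. The sketch below uses the W&B public API (`wandb.Api`, `Api.sweep`, `Run.history`); it is not the code in `figures/wandb_utils.py`, and the function names `sweep_path` and `fetch_sweep_histories` are hypothetical. The entity, project, and sweep ID must be supplied by you.

```python
def sweep_path(entity: str, project: str, sweep_id: str) -> str:
    """Build the "entity/project/sweep_id" path expected by the W&B public API."""
    return f"{entity}/{project}/{sweep_id}"


def fetch_sweep_histories(entity: str, project: str, sweep_id: str, keys=None):
    """Return {run_id: {"config": ..., "history": ...}} for every run in a sweep."""
    import wandb  # imported lazily so the path helper works without wandb installed

    api = wandb.Api()
    sweep = api.sweep(sweep_path(entity, project, sweep_id))
    results = {}
    for run in sweep.runs:
        results[run.id] = {
            "config": dict(run.config),
            # history() returns a sampled pandas DataFrame of logged metrics;
            # `keys` restricts it to the named metrics when given.
            "history": run.history(keys=keys),
        }
    return results
```

Because `run.history()` returns a sampled history, `run.scan_history()` is the alternative when every logged step is needed.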
