# Code for "Transformers Learn Faster with Semantic Focus"

Experiments with transformers using various forms of sparse attention.

## Environment setup

### Installation

The steps below assume that CUDA is already set up.

Python version: 3.11.5

```
> cd trfexts
> mkdir env
> virtualenv -p /usr/bin/python3.11 env
> source env/bin/activate
> pip install -r requirements.txt
> mkdir checkpoints
> mkdir results
```


## Data

### ListOps

The data in the `./data/listops/` directory was generated with code from the [Long Range Arena](https://github.com/google-research/long-range-arena) benchmark, following this procedure:

```
> git clone git@github.com:google-research/long-range-arena.git
> cd long-range-arena
> mkdir env
> virtualenv -p /usr/bin/python3.11 env
> source env/bin/activate
> pip install -r requirements.txt
> python lra_benchmarks/data/listops.py --max_args <NARGS> --max_depth <DEPTH> \
                                        --max_length <MAXLEN> --min_length <MINLEN> \
                                        --num_train_samples <NTRAIN> \
                                        --num_valid_samples <NVAL> \
                                        --num_test_samples <NTEST> \
                                        --output_dir D<DEPTH>-A<NARGS>-l<MINLEN>-L<MAXLEN>-<NTRAIN>-<NVAL>-<NTEST>
```

For our experiments, we used: 
```
<NARGS> = 10
<DEPTH> = 10
<MINLEN> = 500
<MAXLEN> = 600
<NTRAIN> = 5000
<NVAL> = 2000
<NTEST> = 2000
```
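As a small illustration of how the `--output_dir` pattern in the command above expands, the following sketch formats the directory name from the listed values. The helper name `listops_output_dir` is hypothetical and is not part of this repository or of Long Range Arena.

```python
def listops_output_dir(nargs, depth, minlen, maxlen, ntrain, nval, ntest):
    """Format the directory name following the --output_dir pattern shown above.

    Hypothetical helper for illustration only.
    """
    return f"D{depth}-A{nargs}-l{minlen}-L{maxlen}-{ntrain}-{nval}-{ntest}"


# Values reported above for our experiments.
params = dict(nargs=10, depth=10, minlen=500, maxlen=600,
              ntrain=5000, nval=2000, ntest=2000)
out_dir = listops_output_dir(**params)
print(out_dir)  # D10-A10-l500-L600-5000-2000-2000
```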

### Neural Networks and the Chomsky Hierarchy

The data for the tasks from the NNCH benchmark is generated with code in `./datasets.py`, which we ported to PyTorch from the following JAX scripts of the [Neural Networks and the Chomsky Hierarchy](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/) benchmark:
- [parity task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/regular/parity_check.py)
- [even pairs task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/regular/even_pairs.py)
- [missing duplicates task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/cs/missing_duplicate_string.py)
- [stack manipulation task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/dcf/stack_manipulation.py)
- [cycle navigation task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/regular/cycle_navigation.py)
- [modular arithmetic with brackets task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/dcf/modular_arithmetic_brackets.py)
- [solve equation task](https://github.com/google-deepmind/neural_networks_chomsky_hierarchy/blob/main/tasks/dcf/solve_equation.py)
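To give a feel for what these generators produce, here is a minimal pure-Python sketch of a parity-style sampler. It is not the code in `./datasets.py`: the function names, the `max_length` parameter, and the uniform length sampling are assumptions made for illustration, and the label is assumed to be the number of ones modulo 2.

```python
import random


def parity_label(bits):
    """Return 0 if the number of ones in `bits` is even, 1 if it is odd."""
    return sum(bits) % 2


def sample_parity_example(max_length, rng):
    """Sample a random bit string and its parity label.

    Assumption: the length is drawn uniformly from [1, max_length].
    """
    length = rng.randint(1, max_length)  # inclusive on both ends
    bits = [rng.randint(0, 1) for _ in range(length)]
    return bits, parity_label(bits)


rng = random.Random(0)  # fixed seed so the samples are reproducible
for _ in range(3):
    bits, label = sample_parity_example(max_length=8, rng=rng)
    print("".join(map(str, bits)), "->", label)
```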

## Experiment hyperparameters

- 10 repetitions
- Transformer variants:
  - Standard
  - Banded, with band size 5 or 9 and 0, 1, or 3 global tokens
  - Blocklocal, with block size 5 or 9 and 0, 1, or 3 global tokens
  - Top-k, with k = 5 or 9

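The sparse variants listed above can be pictured as boolean attention masks. The sketch below is a pure-Python illustration under several assumptions that the list does not confirm: the band size is taken as the total window width (so a band of 5 covers two neighbours on each side), masks are symmetric rather than causal, and global tokens are the first positions of the sequence, attending to and attended by every token. The function names are hypothetical and do not come from this repository.

```python
def banded_mask(n, band, n_global=0):
    """mask[i][j] is True when query i may attend to key j.

    Assumption: `band` is the total window width, so each token sees
    band // 2 neighbours on each side; the first `n_global` tokens are global.
    """
    half = band // 2
    return [
        [abs(i - j) <= half or i < n_global or j < n_global for j in range(n)]
        for i in range(n)
    ]


def blocklocal_mask(n, block, n_global=0):
    """Tokens attend only within their own block of `block` consecutive positions."""
    return [
        [i // block == j // block or i < n_global or j < n_global for j in range(n)]
        for i in range(n)
    ]


def topk_mask(scores, k):
    """Keep the k largest scores in each row (ties go to the lower index)."""
    mask = []
    for row in scores:
        keep = set(sorted(range(len(row)), key=lambda j: -row[j])[:k])
        mask.append([j in keep for j in range(len(row))])
    return mask


def show(mask):
    """Render a mask as text: 'x' for allowed pairs, '.' for masked pairs."""
    return "\n".join("".join("x" if m else "." for m in row) for row in mask)


print(show(banded_mask(8, band=5, n_global=1)))
print()
print(show(blocklocal_mask(8, block=4)))
```

Unlike the banded and blocklocal masks, which depend only on positions, the top-k mask depends on the attention scores themselves, so it has to be recomputed for every input.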

## ListOps experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 10
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 1.0
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 200

```
> mkdir results/listops
> mkdir checkpoints/listops
> bash lra-lops.sh relu
> bash lra-lops.sh gelu
> bash lra-lops.sh mish
```
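To illustrate how the learning rate and decay rate listed above might interact, the sketch below assumes a multiplicative decay applied once per epoch (the behaviour of `torch.optim.lr_scheduler.ExponentialLR` when stepped once per epoch). Whether the training scripts apply the decay per epoch or per step is not stated here, so treat this as an assumption; `decayed_lr` is a hypothetical helper.

```python
def decayed_lr(initial_lr, decay_rate, epoch):
    """Learning rate after `epoch` multiplicative decay steps.

    Assumption: one decay step per epoch.
    """
    return initial_lr * decay_rate ** epoch


# ListOps settings listed above: learning rate 1.0, decay rate 0.99, 200 epochs.
for epoch in (0, 50, 100, 200):
    print(epoch, round(decayed_lr(1.0, 0.99, epoch), 4))
```

Under this assumption, the learning rate at the end of the 200 ListOps epochs is roughly 13% of its initial value.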

## Even pairs experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 100
  
```
> mkdir results/ueqpairs
> mkdir checkpoints/ueqpairs
> bash nnch-even-pairs.sh relu
> bash nnch-even-pairs.sh gelu
> bash nnch-even-pairs.sh mish
```


## Parity experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 1000
  
```
> mkdir results/parity
> mkdir checkpoints/parity
> bash nnch-parity.sh relu
> bash nnch-parity.sh gelu
> bash nnch-parity.sh mish
```

## Missing duplicates experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 250
  
```
> mkdir results/missdup
> mkdir checkpoints/missdup
> bash nnch-missdup.sh relu
> bash nnch-missdup.sh gelu
> bash nnch-missdup.sh mish
```

## Stack Manipulation experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 200
  
```
> mkdir results/stackman
> mkdir checkpoints/stackman
> bash nnch-stackman.sh relu
> bash nnch-stackman.sh gelu
> bash nnch-stackman.sh mish
```

## Modular Arithmetic with Brackets experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 600
  
```
> mkdir results/mab
> mkdir checkpoints/mab
> bash nnch-mab.sh relu
> bash nnch-mab.sh gelu
> bash nnch-mab.sh mish
```

## Solve Equation experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 600
  
```
> mkdir results/soleq
> mkdir checkpoints/soleq
> bash nnch-soleq.sh relu
> bash nnch-soleq.sh gelu
> bash nnch-soleq.sh mish
```

## Cycle Navigation experiments

#### Hyperparameters:
- Transformer architecture
  - Number of blocks: 5
  - Number of heads: 1
  - Embedding dimension: 64
  - MLP hidden layer dimension: 64
  - Dropout: 0.01
- Optimization
  - Learning rate: 0.1
  - Learning rate decay rate: 0.99
  - Batch size: 25
  - Epochs: 750
  
```
> mkdir results/cycnav
> mkdir checkpoints/cycnav
> bash nnch-cycnav.sh relu
> bash nnch-cycnav.sh gelu
> bash nnch-cycnav.sh mish
```

