# HE-Friendly Transformer

This repository is based on the [S4 repository](https://github.com/HazyResearch/state-spaces) and is designed to explore architectures of sequence models. Specifically, this repo contains the implementation of **PowerSoftmax** and **Stable PowerSoftmax** variants for efficient and secure sequence modeling.

## Main Components
- The main logic of the **PowerSoftmax** variant is implemented in `src/models/baselines/transformer.py` (see the `PowerSoftmax` class).
- The training procedure is implemented in `train.py`.

## Training Scripts
The commands below train each model on **10%** of the WikiText-103 training set (via `trainer.limit_train_batches=0.1`, i.e. 10% of the training batches in each epoch):
- **PowerSoftmax**
> `python train.py +clrml=False experiment=transformer-wt103 trainer.devices=2 loader.batch_size=20 model.n_layers=32 model.d_model=1024 trainer.max_epochs=20 trainer.accumulate_grad_batches=1 model=pointwize_transformer model.norm=cln trainer.precision=16 +model.layer.0.post_norm=False +model.layer.0.pre_norm=False +model.layer.0.sqrt_norm=False +model.layer.0.qk_sqrt_scale=True loader.l_max=512 +BERTtokenizer=True +task=lm +model.layer.0.PowerSoftmax=True +model.layer.0.p_norm_val=8 train.seed=0 scheduler.num_training_steps=200000 +model.layer.0.stablePowerSoftmax=False trainer.limit_train_batches=0.1 +model.layer.0.attn_offset=1e-4`
- **Stable PowerSoftmax** (identical to the PowerSoftmax command except for `+model.layer.0.stablePowerSoftmax=True`)
> `python train.py +clrml=False experiment=transformer-wt103 trainer.devices=2 loader.batch_size=20 model.n_layers=32 model.d_model=1024 trainer.max_epochs=20 trainer.accumulate_grad_batches=1 model=pointwize_transformer model.norm=cln trainer.precision=16 +model.layer.0.post_norm=False +model.layer.0.pre_norm=False +model.layer.0.sqrt_norm=False +model.layer.0.qk_sqrt_scale=True loader.l_max=512 +BERTtokenizer=True +task=lm +model.layer.0.PowerSoftmax=True +model.layer.0.p_norm_val=8 train.seed=0 scheduler.num_training_steps=200000 +model.layer.0.stablePowerSoftmax=True trainer.limit_train_batches=0.1 +model.layer.0.attn_offset=1e-4`
- **Softmax Baseline**
> `python train.py +clrml=False experiment=transformer-wt103 trainer.devices=2 loader.batch_size=20 model.n_layers=32 model.d_model=1024 trainer.max_epochs=20 trainer.accumulate_grad_batches=1 model=transformer model.norm=cln trainer.precision=16 loader.l_max=512 +BERTtokenizer=True +task=lm train.seed=0 scheduler.num_training_steps=200000 trainer.limit_train_batches=0.1`
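
Homomorphic encryption (HE) evaluates only additions and multiplications, so the exponential in softmax cannot be computed directly, whereas a power function is a polynomial. The NumPy sketch below illustrates the general idea of power-based attention normalization under stated assumptions: `p` stands in for `p_norm_val`, `offset` stands in for `attn_offset`, and `scaled_power_softmax` is one possible stabilization, not necessarily what `stablePowerSoftmax` does. For the actual implementation, see the `PowerSoftmax` class in `src/models/baselines/transformer.py`.

```python
import numpy as np


def power_softmax(scores: np.ndarray, p: int = 8, offset: float = 1e-4) -> np.ndarray:
    """Power-based normalization over the last axis: s**p / (sum_j s_j**p + offset).

    Hypothetical mapping: `p` ~ `p_norm_val`, `offset` ~ `attn_offset`.
    The offset keeps the denominator away from zero.
    """
    if p <= 0 or p % 2 != 0:
        raise ValueError("p must be a positive even integer so that s**p >= 0")
    powered = scores ** p  # polynomial replacement for exp(s)
    return powered / (powered.sum(axis=-1, keepdims=True) + offset)


def scaled_power_softmax(scores: np.ndarray, p: int = 8, offset: float = 1e-4) -> np.ndarray:
    """One possible stabilization (an assumption, not the repo's stablePowerSoftmax).

    Dividing each row by its largest |s| keeps every powered entry in [0, 1],
    avoiding fp16 overflow (s**8 exceeds the fp16 maximum once |s| >= 4).
    With offset=0 the scale cancels and the result equals power_softmax.
    """
    scale = np.abs(scores).max(axis=-1, keepdims=True)
    scale = np.where(scale == 0, 1.0, scale)  # leave all-zero rows unscaled
    return power_softmax(scores / scale, p=p, offset=offset)


# Usage: attention over 4 tokens with 16-dimensional heads.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 16)) for _ in range(3))
scores = q @ k.T / np.sqrt(16)   # scaled dot-product scores, shape (4, 4)
weights = power_softmax(scores)  # rows sum to less than 1 because of the offset
out = weights @ v                # attended values, shape (4, 16)
```

Because `p` is even, every weight is non-negative, and the largest-magnitude score (positive or negative) receives the largest weight, which differs from standard softmax.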


All experiments are designed to run on **2 A100 GPUs** for a maximum of **4 hours**.


### Setup
- Install the dependencies listed in the requirements file (based on the instructions in the [S4 repo](https://github.com/HazyResearch/state-spaces/)).
- Install the following PyTorch and CUDA versions:
> conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia



### ClearML Logging
Logging with ClearML is built into this repository.
To enable it, add `+clrml=True +clrml_name=your_task_name` to the training command, for example:
> `python train.py experiment=transformer-wt103 +clrml=True +clrml_name=your_task_name`

By default, runs are reported to the `security_ai/polynomial-transformer-wikitext103` folder. To change this, add `+clrml_folder=<folder-name>`; runs are then saved to `security_ai/<folder-name>`.
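
A minimal sketch of how the `+clrml*` flags could map to a ClearML project and task name, following the defaults described above. The helper `clearml_destination` is hypothetical; only `Task.init(project_name=..., task_name=...)`, shown in the trailing comment, is part of the ClearML API.

```python
from typing import Optional, Tuple

# Default values stated in this README; the helper below is hypothetical.
CLEARML_ROOT = "security_ai"
DEFAULT_FOLDER = "polynomial-transformer-wikitext103"


def clearml_destination(overrides: dict) -> Optional[Tuple[str, str]]:
    """Return (project_name, task_name) for a run, or None when logging is off.

    `overrides` holds the parsed `+clrml*` flags, e.g.
    {"clrml": True, "clrml_name": "psm-p8", "clrml_folder": "ablations"}.
    `clrml_name` is required whenever `clrml` is True.
    """
    if not overrides.get("clrml", False):
        return None
    folder = overrides.get("clrml_folder", DEFAULT_FOLDER)
    return f"{CLEARML_ROOT}/{folder}", overrides["clrml_name"]


# In real code the result would be handed to ClearML, e.g.:
#   from clearml import Task
#   project, name = clearml_destination(overrides)
#   Task.init(project_name=project, task_name=name)
```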



### Data
The WikiText-103 language modeling dataset can be downloaded by the `getdata.sh` script from the [Transformer-XL codebase](https://github.com/kimiyoung/transformer-xl).
By default, the datamodule looks for it under `$DATA_PATH/wt103`.

```
cd {repo}/data
wget https://raw.githubusercontent.com/kimiyoung/transformer-xl/master/getdata.sh
bash getdata.sh
mv data/wikitext-103 wt103
```

The data is available on the CCC server (), and its location can be configured in one of two ways:
- (1) Environment variable: `export DATA_PATH=/path/to/data`, where `/path/to/data` is the directory that contains `wt103/` (e.g. `export DATA_PATH=/data/`)
- (2) Command-line argument: `python train.py dataset.data_dir=/path/to/wikitext103`
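
The lookup order can be sketched as follows. The helper `resolve_wt103_dir` is hypothetical, and two points are assumptions: that `dataset.data_dir` takes precedence over `DATA_PATH`, and that the fallback root is `{repo}/data`, where the download step above places `wt103`.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_wt103_dir(data_dir: Optional[str] = None, default_root: str = "data") -> Path:
    """Return the directory expected to hold WikiText-103 (hypothetical helper).

    Assumed precedence:
      1. an explicit `dataset.data_dir=...` command-line value;
      2. otherwise `$DATA_PATH/wt103`;
      3. otherwise `<default_root>/wt103`, where `default_root` stands in for
         `{repo}/data` (the download step places `wt103` there).
    """
    if data_dir is not None:
        return Path(data_dir)
    return Path(os.environ.get("DATA_PATH", default_root)) / "wt103"


if __name__ == "__main__":
    path = resolve_wt103_dir()
    print(f"{path} exists: {path.is_dir()}")
```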

## Getting Started & Tips

#### Changing hyper-parameters and model architecture:

Any config value can be overridden by appending `key=value` to the training command. Examples:
```
loader.batch_size=32
trainer.accumulate_grad_batches=1
model=pointwize_transformer
model.norm=batch
trainer.max_epochs=100
model.n_layers=10
model.d_model=256
model.dropout=0.1
```
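
Flags in the training commands above come in two forms: `key=value` overrides a value that already exists in the config, while `+key=value` adds a key that does not exist yet (Hydra raises an error if the wrong form is used). The following simplified, hypothetical sketch mimics these two rules on a nested dict. It is not Hydra's implementation, and it does not model config-group choices such as `model=pointwize_transformer`, which select a whole config file.

```python
from typing import Any


def parse_value(raw: str) -> Any:
    """Simplified value parsing: booleans, ints, floats, otherwise the raw string."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_override(cfg: dict, override: str) -> None:
    """Apply one `key=value` or `+key=value` override to a nested dict in place.

    A plain key must already exist; a `+`-prefixed key must not exist yet.
    Numeric path segments index into lists (e.g. `model.layer.0`).
    """
    add = override.startswith("+")
    key, raw = (override[1:] if add else override).split("=", 1)
    *parents, leaf = key.split(".")
    node: Any = cfg
    for part in parents:
        node = node[int(part)] if isinstance(node, list) else node[part]
    if add and leaf in node:
        raise KeyError(f"'{key}' already exists; drop the '+' to override it")
    if not add and leaf not in node:
        raise KeyError(f"'{key}' is not in the config; prefix it with '+' to add it")
    node[leaf] = parse_value(raw)


# Usage on a hypothetical config fragment:
cfg = {"loader": {"batch_size": 50}, "model": {"n_layers": 4, "layer": [{"causal": True}]}}
for flag in ["loader.batch_size=20", "+model.layer.0.p_norm_val=8", "+clrml=False"]:
    apply_override(cfg, flag)
```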

#### Hydra

It is recommended to read the [Hydra documentation](https://hydra.cc/docs/intro/) to fully understand the configuration framework. For help launching specific experiments, please file an issue.

<!--
#### Registries

This codebase uses a modification of the hydra `instantiate` utility that provides shorthand names of different classes, for convenience in configuration and logging.
The mapping from shorthand to full path can be found in `src/utils/registry.py`.
-->


#### Resuming

Each experiment is logged to its own directory, generated by Hydra, of the form `./outputs/<date>/<time>/`. Checkpoints are saved inside this folder, and their paths are printed to the console whenever a new checkpoint is created.
To resume training, simply point to the desired `.ckpt` file (a PyTorch Lightning checkpoint, e.g. `./outputs/<date>/<time>/checkpoints/val/loss.ckpt`) and append the flag `train.ckpt=<path>/<to>/<checkpoint>.ckpt` to the original training command.
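
To find the newest checkpoint without browsing the output folders by hand, a small helper along these lines could be used. It is a sketch that assumes the `./outputs/<date>/<time>/checkpoints/` layout described above; `latest_checkpoint` is a hypothetical name, and it picks the most recently written file, which is not necessarily the checkpoint with the best validation loss.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(outputs_dir: str = "outputs") -> Optional[Path]:
    """Return the most recently modified `.ckpt` under <outputs_dir>/<date>/<time>/checkpoints/."""
    # `**` matches zero or more directories, so both checkpoints/*.ckpt
    # and nested files such as checkpoints/val/loss.ckpt are found.
    ckpts = list(Path(outputs_dir).glob("*/*/checkpoints/**/*.ckpt"))
    if not ckpts:
        return None
    return max(ckpts, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    ckpt = latest_checkpoint()
    # Append the printed flag to the original training command.
    print(f"train.ckpt={ckpt.resolve()}" if ckpt else "no checkpoint found")
```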

#### Debugging with subsampling
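Setting `trainer.limit_train_batches` (or `trainer.limit_val_batches`) to an integer such as `10` trains (validates) on only 10 batches per epoch, while a float such as `0.1` uses that fraction (here 10%) of all batches. This is useful for testing the training loop without going through all the data.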
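
The int-versus-float rule can be written out as a small helper. This is a sketch of the documented behavior with a hypothetical name, not PyTorch Lightning's internal code; in particular, rounding a fraction down is an assumption.

```python
from typing import Union


def batches_per_epoch(total_batches: int, limit: Union[int, float]) -> int:
    """Batches run per epoch for a given `limit_{train,val}_batches` value (hypothetical helper).

    An int caps the number of batches; a float in [0, 1] selects that fraction,
    rounded down (the rounding is an assumption).
    """
    if isinstance(limit, bool) or not isinstance(limit, (int, float)):
        raise TypeError("limit must be an int or a float")
    if isinstance(limit, int):
        return min(limit, total_batches)
    if not 0.0 <= limit <= 1.0:
        raise ValueError("a float limit must lie in [0, 1]")
    return int(total_batches * limit)
```

Note the type matters: `1` (int) means a single batch, while `1.0` (float) means all batches.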



#### PyTorch Lightning Trainer

The PTL [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) class controls the overall training loop and also provides many useful pre-defined flags. Some useful examples are explained below.
The full list of allowable flags can be found in the PTL documentation, as well as our [trainer configs](configs/trainer/). See the default trainer config [configs/trainer/default.yaml](configs/trainer/default.yaml) for the most useful options.

#### Multi-GPU training
Pass `trainer.devices=2` to train on 2 GPUs. Multi-node training has not been tested.

#### Inspect model layers
`trainer.weights_summary=full` prints every layer of the model together with its parameter count. This is useful for debugging model internals.
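
As a rough cross-check of the summary, the parameter count of a generic Transformer stack can be estimated by hand. The sketch below assumes a standard block (four `d_model x d_model` attention projections, a 4x feed-forward layer, biases, and two normalization layers with scale and shift); `transformer_params` is a hypothetical helper, and the repository's `cln` norm and PowerSoftmax layer may differ, so the number is a ballpark, not the expected output of `weights_summary`.

```python
def transformer_params(n_layers: int, d_model: int, ffn_mult: int = 4, vocab_size: int = 0) -> int:
    """Rough parameter count for a generic Transformer stack (assumed block layout)."""
    d, h = d_model, ffn_mult * d_model
    attn = 4 * (d * d + d)           # Q, K, V and output projections, each with a bias
    ffn = (d * h + h) + (h * d + d)  # two linear layers with biases
    norms = 2 * (2 * d)              # two norm layers, each with a scale and a shift
    embed = vocab_size * d           # token embedding; vocab_size=0 leaves it out
    return n_layers * (attn + ffn + norms) + embed


# Non-embedding estimate for model.n_layers=32, model.d_model=1024:
print(f"{transformer_params(n_layers=32, d_model=1024):,}")  # prints 403,079,168
```

Passing the tokenizer's vocabulary size as `vocab_size` adds the embedding table on top of this estimate.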

#### Training with range loss:
 - Activations: `+RL=True +RLC=0.1`
 - Layer norm: `+LNL=True +LNLC=0.1`
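
The general idea of a range loss in HE-friendly training is to penalize activations that leave the interval `[-bound, bound]` on which a polynomial approximation is accurate. The sketch below illustrates that idea only: it assumes `RLC` (or `LNLC`) is the penalty coefficient, and the squared-excess form, the `bound` value, and the helper names are hypothetical rather than taken from this repository.

```python
from typing import Sequence

import numpy as np


def range_penalty(x: np.ndarray, bound: float) -> float:
    """Mean squared excess of |x| over `bound`; zero when all values lie in [-bound, bound]."""
    excess = np.maximum(np.abs(x) - bound, 0.0)
    return float(np.mean(excess ** 2))


def loss_with_range_penalty(task_loss: float, activations: Sequence[np.ndarray],
                            bound: float, coeff: float = 0.1) -> float:
    """Task loss plus `coeff` times the summed penalty over the monitored tensors.

    `coeff` is assumed to play the role of RLC (or LNLC); `bound` is hypothetical.
    """
    return task_loss + coeff * sum(range_penalty(a, bound) for a in activations)
```

In actual training the penalty would be computed on framework tensors so that gradients flow through it; NumPy is used here only to show the arithmetic.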
